Compare commits
67 Commits
55e6b0583d
...
dev
| Author | SHA1 | Date | |
|---|---|---|---|
| d86888704f | |||
| de6b6e20a5 | |||
| 1e8a5e08ed | |||
| 218ebbf32f | |||
| c49e7f4b22 | |||
| 9588c8c562 | |||
| 1948ac1284 | |||
| 3f92b7d963 | |||
| 5553e61dbf | |||
| 7f987737f9 | |||
| 5182f86133 | |||
| a50099ad74 | |||
| 20ba5523ee | |||
| 0b2b3701dc | |||
| 438b05b8a3 | |||
| e2a31b192f | |||
| b827d3d047 | |||
| 9c0cf274a3 | |||
| 85ae319690 | |||
| 449f133a1f | |||
| 2f6b03ef65 | |||
| d4030dc598 | |||
| 3271697f6b | |||
| cbfef5a5df | |||
| 52efd5f341 | |||
| 200cdbc4bd | |||
| 8525819ab4 | |||
| bcd52d526c | |||
| 7effade1d3 | |||
| dc0fee2ee3 | |||
| ea04a25ed6 | |||
| 282dcdce88 | |||
| b49f58bc16 | |||
| cdc425ae93 | |||
| 3525cb3949 | |||
| 9d85420bf6 | |||
| 641c95131f | |||
| 708c626176 | |||
| 5210e196f2 | |||
| 30c375b6c5 | |||
| baf49b1e69 | |||
| 96e0436d43 | |||
| 498e6e61b6 | |||
| 99064b6c41 | |||
| ee58b0ac32 | |||
| 990f93d467 | |||
| 44a00619b5 | |||
| 6923ee439f | |||
| c997b19b53 | |||
| c9daf68fea | |||
| ba9d083088 | |||
| 825dfc0722 | |||
| 3e4eacd1d3 | |||
| 23253219a3 | |||
| cc2b85a86d | |||
| 58dd6f3efa | |||
| c81d0f1593 | |||
| ae0dd3fc51 | |||
| 80dffa9f41 | |||
| ab0ae4fe04 | |||
| d31e068277 | |||
| 690f5c7056 | |||
| 0da8a3f193 | |||
| 15f81d9728 | |||
| b80db89391 | |||
| f413a63c5a | |||
| 33ad3797a1 |
34
.github/workflows/macos-check.yml
vendored
Normal file
34
.github/workflows/macos-check.yml
vendored
Normal file
@@ -0,0 +1,34 @@
|
||||
name: macos-check
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- dev
|
||||
pull_request:
|
||||
branches:
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: cargo check (macOS)
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout sources
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Rust toolchain
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
|
||||
- name: Cache Cargo registry
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/registry
|
||||
~/.cargo/git
|
||||
target
|
||||
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-cargo-
|
||||
|
||||
- name: Cargo check
|
||||
run: cargo check --workspace --all-features
|
||||
@@ -9,6 +9,7 @@ repos:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-yaml
|
||||
args: ['--allow-multiple-documents']
|
||||
- id: check-toml
|
||||
- id: check-merge-conflict
|
||||
- id: check-added-large-files
|
||||
|
||||
@@ -1,3 +1,61 @@
|
||||
---
|
||||
kind: pipeline
|
||||
name: pr-checks
|
||||
|
||||
when:
|
||||
event:
|
||||
- push
|
||||
- pull_request
|
||||
|
||||
steps:
|
||||
- name: fmt-clippy-test
|
||||
image: rust:1.83
|
||||
commands:
|
||||
- rustup component add rustfmt clippy
|
||||
- cargo fmt --all -- --check
|
||||
- cargo clippy --workspace --all-features -- -D warnings
|
||||
- cargo test --workspace --all-features
|
||||
|
||||
---
|
||||
kind: pipeline
|
||||
name: security-audit
|
||||
|
||||
when:
|
||||
event:
|
||||
- push
|
||||
- cron
|
||||
branch:
|
||||
- dev
|
||||
cron: weekly-security
|
||||
|
||||
steps:
|
||||
- name: cargo-audit
|
||||
image: rust:1.83
|
||||
commands:
|
||||
- cargo install cargo-audit --locked
|
||||
- cargo audit
|
||||
|
||||
---
|
||||
kind: pipeline
|
||||
name: release-tests
|
||||
|
||||
when:
|
||||
event: tag
|
||||
tag: v*
|
||||
|
||||
steps:
|
||||
- name: workspace-tests
|
||||
image: rust:1.83
|
||||
commands:
|
||||
- rustup component add llvm-tools-preview
|
||||
- cargo install cargo-llvm-cov --locked
|
||||
- cargo llvm-cov --workspace --all-features --summary-only
|
||||
- cargo llvm-cov --workspace --all-features --lcov --output-path coverage.lcov --no-run
|
||||
|
||||
---
|
||||
kind: pipeline
|
||||
name: release
|
||||
|
||||
when:
|
||||
event: tag
|
||||
tag: v*
|
||||
@@ -5,6 +63,9 @@ when:
|
||||
variables:
|
||||
- &rust_image 'rust:1.83'
|
||||
|
||||
depends_on:
|
||||
- release-tests
|
||||
|
||||
matrix:
|
||||
include:
|
||||
# Linux
|
||||
@@ -39,14 +100,6 @@ matrix:
|
||||
EXT: ".exe"
|
||||
|
||||
steps:
|
||||
- name: tests
|
||||
image: *rust_image
|
||||
commands:
|
||||
- rustup component add llvm-tools-preview
|
||||
- cargo install cargo-llvm-cov --locked
|
||||
- cargo llvm-cov --workspace --all-features --summary-only
|
||||
- cargo llvm-cov --workspace --all-features --lcov --output-path coverage.lcov --no-run
|
||||
|
||||
- name: build
|
||||
image: *rust_image
|
||||
commands:
|
||||
@@ -124,6 +177,11 @@ steps:
|
||||
sha256sum ${ARTIFACT}.tar.gz > ${ARTIFACT}.tar.gz.sha256
|
||||
fi
|
||||
|
||||
- name: release-notes
|
||||
image: *rust_image
|
||||
commands:
|
||||
- scripts/release-notes.sh "${CI_COMMIT_TAG}" release-notes.md
|
||||
|
||||
- name: release
|
||||
image: plugins/gitea-release
|
||||
settings:
|
||||
@@ -136,4 +194,4 @@ steps:
|
||||
- ${ARTIFACT}.zip
|
||||
- ${ARTIFACT}.zip.sha256
|
||||
title: Release ${CI_COMMIT_TAG}
|
||||
note: "Release ${CI_COMMIT_TAG}"
|
||||
note_file: release-notes.md
|
||||
|
||||
@@ -11,6 +11,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Comprehensive documentation suite including guides for architecture, configuration, testing, and more.
|
||||
- Rustdoc examples for core components like `Provider` and `SessionController`.
|
||||
- Module-level documentation for `owlen-tui`.
|
||||
- Provider integration tests (`crates/owlen-providers/tests`) covering registration, routing, and health status handling for the new `ProviderManager`.
|
||||
- TUI message and generation tests that exercise the non-blocking event loop, background worker, and message dispatch.
|
||||
- Ollama integration can now talk to Ollama Cloud when an API key is configured.
|
||||
- Ollama provider will also read `OLLAMA_API_KEY` / `OLLAMA_CLOUD_API_KEY` environment variables when no key is stored in the config.
|
||||
- `owlen config doctor`, `owlen config path`, and `owlen upgrade` CLI commands to automate migrations and surface manual update steps.
|
||||
@@ -26,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- Input panel respects a new `ui.input_max_rows` setting so long prompts expand predictably before scrolling kicks in.
|
||||
- Command palette offers fuzzy `:model` filtering and `:provider` completions for fast switching.
|
||||
- Message rendering caches wrapped lines and throttles streaming redraws to keep the TUI responsive on long sessions.
|
||||
- Model picker badges now inspect provider capabilities so vision/audio/thinking models surface the correct icons even when descriptions are sparse.
|
||||
- Chat history honors `ui.scrollback_lines`, trimming older rows to keep the TUI responsive and surfacing a "↓ New messages" badge whenever updates land off-screen.
|
||||
|
||||
### Changed
|
||||
@@ -38,9 +41,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
- `config.toml` now carries a schema version (`1.2.0`) and is migrated automatically; deprecated keys such as `agent.max_tool_calls` trigger warnings instead of hard failures.
|
||||
- Model selector navigation (Tab/Shift-Tab) now switches between local and cloud tabs while preserving selection state.
|
||||
- Header displays the active model together with its provider (e.g., `Model (Provider)`), improving clarity when swapping backends.
|
||||
- Documentation refreshed to cover the message handler architecture, the background health worker, multi-provider configuration, and the new provider onboarding checklist.
|
||||
|
||||
---
|
||||
|
||||
## [0.1.11] - 2025-10-18
|
||||
|
||||
### Changed
|
||||
- Bump workspace packages and distribution metadata to version `0.1.11`.
|
||||
|
||||
## [0.1.10] - 2025-10-03
|
||||
|
||||
### Added
|
||||
|
||||
@@ -10,6 +10,10 @@ This project and everyone participating in it is governed by the [Owlen Code of
|
||||
|
||||
## How Can I Contribute?
|
||||
|
||||
### Repository map
|
||||
|
||||
Need a quick orientation before diving in? Start with the curated [repo map](docs/repo-map.md) for a two-level directory overview. If you move folders around, regenerate it with `scripts/gen-repo-map.sh`.
|
||||
|
||||
### Reporting Bugs
|
||||
|
||||
This is one of the most helpful ways you can contribute. Before creating a bug report, please check a few things:
|
||||
|
||||
15
Cargo.toml
15
Cargo.toml
@@ -4,16 +4,19 @@ members = [
|
||||
"crates/owlen-core",
|
||||
"crates/owlen-tui",
|
||||
"crates/owlen-cli",
|
||||
"crates/owlen-mcp-server",
|
||||
"crates/owlen-mcp-llm-server",
|
||||
"crates/owlen-mcp-client",
|
||||
"crates/owlen-mcp-code-server",
|
||||
"crates/owlen-mcp-prompt-server",
|
||||
"crates/owlen-providers",
|
||||
"crates/mcp/server",
|
||||
"crates/mcp/llm-server",
|
||||
"crates/mcp/client",
|
||||
"crates/mcp/code-server",
|
||||
"crates/mcp/prompt-server",
|
||||
"crates/owlen-markdown",
|
||||
"xtask",
|
||||
]
|
||||
exclude = []
|
||||
|
||||
[workspace.package]
|
||||
version = "0.1.9"
|
||||
version = "0.1.11"
|
||||
edition = "2024"
|
||||
authors = ["Owlibou"]
|
||||
license = "AGPL-3.0"
|
||||
|
||||
2
PKGBUILD
2
PKGBUILD
@@ -1,6 +1,6 @@
|
||||
# Maintainer: vikingowl <christian@nachtigall.dev>
|
||||
pkgname=owlen
|
||||
pkgver=0.1.9
|
||||
pkgver=0.1.11
|
||||
pkgrel=1
|
||||
pkgdesc="Terminal User Interface LLM client for Ollama with chat and code assistance features"
|
||||
arch=('x86_64')
|
||||
|
||||
58
README.md
58
README.md
@@ -3,16 +3,17 @@
|
||||
> Terminal-native assistant for running local language models with a comfortable TUI.
|
||||
|
||||

|
||||

|
||||

|
||||

|
||||

|
||||
|
||||
## What Is OWLEN?
|
||||
|
||||
OWLEN is a Rust-powered, terminal-first interface for interacting with local large
|
||||
language models. It provides a responsive chat workflow that runs against
|
||||
[Ollama](https://ollama.com/) with a focus on developer productivity, vim-style navigation,
|
||||
and seamless session management—all without leaving your terminal.
|
||||
OWLEN is a Rust-powered, terminal-first interface for interacting with local and cloud
|
||||
language models. It provides a responsive chat workflow that now routes through a
|
||||
multi-provider manager—handling local Ollama, Ollama Cloud, and future MCP-backed providers—
|
||||
with a focus on developer productivity, vim-style navigation, and seamless session
|
||||
management—all without leaving your terminal.
|
||||
|
||||
## Alpha Status
|
||||
|
||||
@@ -32,7 +33,9 @@ The OWLEN interface features a clean, multi-panel layout with vim-inspired navig
|
||||
- **Session Management**: Save, load, and manage conversations.
|
||||
- **Code Side Panel**: Switch to code mode (`:mode code`) and open files inline with `:open <path>` for LLM-assisted coding.
|
||||
- **Theming System**: 10 built-in themes and support for custom themes.
|
||||
- **Modular Architecture**: Extensible provider system (Ollama today, additional providers on the roadmap).
|
||||
- **Modular Architecture**: Extensible provider system orchestrated by the new `ProviderManager`, ready for additional MCP-backed providers.
|
||||
- **Dual-Source Model Picker**: Merge local and cloud catalogues with real-time availability badges powered by the background health worker.
|
||||
- **Non-Blocking UI Loop**: Asynchronous generation tasks and provider health checks run off-thread, keeping the TUI responsive even while streaming long replies.
|
||||
- **Guided Setup**: `owlen config doctor` upgrades legacy configs and verifies your environment in seconds.
|
||||
|
||||
## Security & Privacy
|
||||
@@ -54,20 +57,28 @@ Owlen is designed to keep data local by default while still allowing controlled
|
||||
|
||||
### Installation
|
||||
|
||||
#### Linux & macOS
|
||||
The recommended way to install on Linux and macOS is to clone the repository and install using `cargo`.
|
||||
Pick the option that matches your platform and appetite for source builds:
|
||||
|
||||
| Platform | Package / Command | Notes |
|
||||
| --- | --- | --- |
|
||||
| Arch Linux | `yay -S owlen-git` | Builds from the latest `dev` branch via AUR. |
|
||||
| Other Linux | `cargo install --path crates/owlen-cli --locked --force` | Requires Rust 1.75+ and a running Ollama daemon. |
|
||||
| macOS | `cargo install --path crates/owlen-cli --locked --force` | macOS 12+ tested. Install Ollama separately (`brew install ollama`). The binary links against the system OpenSSL – ensure Command Line Tools are installed. |
|
||||
| Windows (experimental) | `cargo install --path crates/owlen-cli --locked --force` | Enable the GNU toolchain (`rustup target add x86_64-pc-windows-gnu`) and install Ollama for Windows preview builds. Some optional tools (e.g., Docker-based code execution) are currently disabled. |
|
||||
|
||||
If you prefer containerised builds, use the provided `Dockerfile` as a base image and copy out `target/release/owlen`.
|
||||
|
||||
Run the helper scripts to sanity-check platform coverage:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Owlibou/owlen.git
|
||||
cd owlen
|
||||
cargo install --path crates/owlen-cli
|
||||
# Windows compatibility smoke test (GNU toolchain)
|
||||
scripts/check-windows.sh
|
||||
|
||||
# Reproduce CI packaging locally (choose a target from .woodpecker.yml)
|
||||
dev/local_build.sh x86_64-unknown-linux-gnu
|
||||
```
|
||||
**Note for macOS**: While this method works, official binary releases for macOS are planned for the future.
|
||||
|
||||
#### Windows
|
||||
The Windows build has not been thoroughly tested yet. Installation is possible via the same `cargo install` method, but it is considered experimental at this time.
|
||||
|
||||
From Unix hosts you can run `scripts/check-windows.sh` to ensure the code base still compiles for Windows (`rustup` will install the required target automatically).
|
||||
> **Tip (macOS):** On the first launch macOS Gatekeeper may quarantine the binary. Clear the attribute (`xattr -d com.apple.quarantine $(which owlen)`) or build from source locally to avoid notarisation prompts.
|
||||
|
||||
### Running OWLEN
|
||||
|
||||
@@ -90,8 +101,16 @@ OWLEN uses a modal, vim-inspired interface. Press `F1` (available from any mode)
|
||||
|
||||
- **Normal Mode**: Navigate with `h/j/k/l`, `w/b`, `gg/G`.
|
||||
- **Editing Mode**: Enter with `i` or `a`. Send messages with `Enter`.
|
||||
- **Command Mode**: Enter with `:`. Access commands like `:quit`, `:save`, `:theme`.
|
||||
- **Command Mode**: Enter with `:`. Access commands like `:quit`, `:w`, `:session save`, `:theme`.
|
||||
- **Quick Exit**: Press `Ctrl+C` twice in Normal mode to quit quickly (first press still cancels active generations).
|
||||
- **Tutorial Command**: Type `:tutorial` any time for a quick summary of the most important keybindings.
|
||||
- **MCP Slash Commands**: Owlen auto-registers zero-argument MCP tools as slash commands—type `/mcp__github__list_prs` (for example) to pull remote context directly into the chat log.
|
||||
|
||||
Model discovery commands worth remembering:
|
||||
|
||||
- `:models --local` or `:models --cloud` jump directly to the corresponding section in the picker.
|
||||
- `:cloud setup [--force-cloud-base-url]` stores your cloud API key without clobbering an existing local base URL (unless you opt in with the flag).
|
||||
When a catalogue is unreachable, Owlen now tags the picker with `Local unavailable` / `Cloud unavailable` so you can recover without guessing.
|
||||
|
||||
## Documentation
|
||||
|
||||
@@ -101,7 +120,10 @@ For more detailed information, please refer to the following documents:
|
||||
- **[CHANGELOG.md](CHANGELOG.md)**: A log of changes for each version.
|
||||
- **[docs/architecture.md](docs/architecture.md)**: An overview of the project's architecture.
|
||||
- **[docs/troubleshooting.md](docs/troubleshooting.md)**: Help with common issues.
|
||||
- **[docs/provider-implementation.md](docs/provider-implementation.md)**: A guide for adding new providers.
|
||||
- **[docs/repo-map.md](docs/repo-map.md)**: Snapshot of the workspace layout and key crates.
|
||||
- **[docs/provider-implementation.md](docs/provider-implementation.md)**: Trait-level details for implementing providers.
|
||||
- **[docs/adding-providers.md](docs/adding-providers.md)**: Step-by-step checklist for wiring a provider into the multi-provider architecture and test suite.
|
||||
- **Experimental providers staging area**: [crates/providers/experimental/README.md](crates/providers/experimental/README.md) records the placeholder crates (OpenAI, Anthropic, Gemini) and their current status.
|
||||
- **[docs/platform-support.md](docs/platform-support.md)**: Current OS support matrix and cross-check instructions.
|
||||
|
||||
## Configuration
|
||||
|
||||
29
config.toml
Normal file
29
config.toml
Normal file
@@ -0,0 +1,29 @@
|
||||
[general]
|
||||
default_provider = "ollama_local"
|
||||
default_model = "llama3.2:latest"
|
||||
|
||||
[privacy]
|
||||
encrypt_local_data = true
|
||||
|
||||
[providers.ollama_local]
|
||||
enabled = true
|
||||
provider_type = "ollama"
|
||||
base_url = "http://localhost:11434"
|
||||
|
||||
[providers.ollama_cloud]
|
||||
enabled = false
|
||||
provider_type = "ollama_cloud"
|
||||
base_url = "https://ollama.com"
|
||||
api_key_env = "OLLAMA_CLOUD_API_KEY"
|
||||
|
||||
[providers.openai]
|
||||
enabled = false
|
||||
provider_type = "openai"
|
||||
base_url = "https://api.openai.com/v1"
|
||||
api_key_env = "OPENAI_API_KEY"
|
||||
|
||||
[providers.anthropic]
|
||||
enabled = false
|
||||
provider_type = "anthropic"
|
||||
base_url = "https://api.anthropic.com/v1"
|
||||
api_key_env = "ANTHROPIC_API_KEY"
|
||||
@@ -6,7 +6,7 @@ description = "Dedicated MCP client library for Owlen, exposing remote MCP serve
|
||||
license = "AGPL-3.0"
|
||||
|
||||
[dependencies]
|
||||
owlen-core = { path = "../owlen-core" }
|
||||
owlen-core = { path = "../../owlen-core" }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
@@ -5,6 +5,7 @@
|
||||
//! crates can depend only on `owlen-mcp-client` without pulling in the entire
|
||||
//! core crate internals.
|
||||
|
||||
pub use owlen_core::config::{McpConfigScope, ScopedMcpServer};
|
||||
pub use owlen_core::mcp::remote_client::RemoteMcpClient;
|
||||
pub use owlen_core::mcp::{McpClient, McpToolCall, McpToolDescriptor, McpToolResponse};
|
||||
|
||||
@@ -6,7 +6,7 @@ description = "MCP server exposing safe code execution tools for Owlen"
|
||||
license = "AGPL-3.0"
|
||||
|
||||
[dependencies]
|
||||
owlen-core = { path = "../owlen-core" }
|
||||
owlen-core = { path = "../../owlen-core" }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
@@ -4,7 +4,7 @@ version = "0.1.0"
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
owlen-core = { path = "../owlen-core" }
|
||||
owlen-core = { path = "../../owlen-core" }
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
@@ -126,7 +126,7 @@ fn provider_from_config() -> Result<Arc<dyn Provider>, RpcError> {
|
||||
})?;
|
||||
|
||||
match provider_cfg.provider_type.as_str() {
|
||||
"ollama" | "ollama-cloud" => {
|
||||
"ollama" | "ollama_cloud" => {
|
||||
let provider = OllamaProvider::from_config(&provider_cfg, Some(&config.general))
|
||||
.map_err(|e| {
|
||||
RpcError::internal_error(format!(
|
||||
@@ -153,10 +153,12 @@ fn create_provider() -> Result<Arc<dyn Provider>, RpcError> {
|
||||
}
|
||||
|
||||
fn canonical_provider_name(name: &str) -> String {
|
||||
if name.eq_ignore_ascii_case("ollama-cloud") {
|
||||
"ollama".to_string()
|
||||
} else {
|
||||
name.to_string()
|
||||
let normalized = name.trim().to_ascii_lowercase().replace('-', "_");
|
||||
match normalized.as_str() {
|
||||
"" => "ollama_local".to_string(),
|
||||
"ollama" | "ollama_local" => "ollama_local".to_string(),
|
||||
"ollama_cloud" => "ollama_cloud".to_string(),
|
||||
other => other.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ description = "MCP server that renders prompt templates (YAML) for Owlen"
|
||||
license = "AGPL-3.0"
|
||||
|
||||
[dependencies]
|
||||
owlen-core = { path = "../owlen-core" }
|
||||
owlen-core = { path = "../../owlen-core" }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
serde_yaml = { workspace = true }
|
||||
@@ -9,4 +9,4 @@ serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
path-clean = "1.0"
|
||||
owlen-core = { path = "../owlen-core" }
|
||||
owlen-core = { path = "../../owlen-core" }
|
||||
@@ -17,6 +17,11 @@ name = "owlen"
|
||||
path = "src/main.rs"
|
||||
required-features = ["chat-client"]
|
||||
|
||||
[[bin]]
|
||||
name = "owlen-code"
|
||||
path = "src/code_main.rs"
|
||||
required-features = ["chat-client"]
|
||||
|
||||
[[bin]]
|
||||
name = "owlen-agent"
|
||||
path = "src/agent_main.rs"
|
||||
@@ -24,6 +29,7 @@ required-features = ["chat-client"]
|
||||
|
||||
[dependencies]
|
||||
owlen-core = { path = "../owlen-core" }
|
||||
owlen-providers = { path = "../owlen-providers" }
|
||||
# Optional TUI dependency, enabled by the "chat-client" feature.
|
||||
owlen-tui = { path = "../owlen-tui", optional = true }
|
||||
log = { workspace = true }
|
||||
|
||||
326
crates/owlen-cli/src/bootstrap.rs
Normal file
326
crates/owlen-cli/src/bootstrap.rs
Normal file
@@ -0,0 +1,326 @@
|
||||
use std::borrow::Cow;
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use async_trait::async_trait;
|
||||
use crossterm::{
|
||||
event::{DisableBracketedPaste, DisableMouseCapture, EnableBracketedPaste, EnableMouseCapture},
|
||||
execute,
|
||||
terminal::{EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode},
|
||||
};
|
||||
use futures::stream;
|
||||
use owlen_core::{
|
||||
ChatStream, Error, Provider,
|
||||
config::{Config, McpMode},
|
||||
mcp::remote_client::RemoteMcpClient,
|
||||
mode::Mode,
|
||||
provider::ProviderManager,
|
||||
providers::OllamaProvider,
|
||||
session::{ControllerEvent, SessionController},
|
||||
storage::StorageManager,
|
||||
types::{ChatRequest, ChatResponse, Message, ModelInfo},
|
||||
};
|
||||
use owlen_tui::{
|
||||
ChatApp, SessionEvent,
|
||||
app::App as RuntimeApp,
|
||||
config,
|
||||
tui_controller::{TuiController, TuiRequest},
|
||||
ui,
|
||||
};
|
||||
use ratatui::{Terminal, prelude::CrosstermBackend};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::commands::cloud::{load_runtime_credentials, set_env_var};
|
||||
|
||||
pub async fn launch(initial_mode: Mode) -> Result<()> {
|
||||
set_env_var("OWLEN_AUTO_CONSENT", "1");
|
||||
|
||||
let color_support = detect_terminal_color_support();
|
||||
let mut cfg = config::try_load_config().unwrap_or_default();
|
||||
let _ = cfg.refresh_mcp_servers(None);
|
||||
|
||||
if let Some(previous_theme) = apply_terminal_theme(&mut cfg, &color_support) {
|
||||
let term_label = match &color_support {
|
||||
TerminalColorSupport::Limited { term } => Cow::from(term.as_str()),
|
||||
TerminalColorSupport::Full => Cow::from("current terminal"),
|
||||
};
|
||||
eprintln!(
|
||||
"Terminal '{}' lacks full 256-color support. Using '{}' theme instead of '{}'.",
|
||||
term_label, BASIC_THEME_NAME, previous_theme
|
||||
);
|
||||
} else if let TerminalColorSupport::Limited { term } = &color_support {
|
||||
eprintln!(
|
||||
"Warning: terminal '{}' may not fully support 256-color themes.",
|
||||
term
|
||||
);
|
||||
}
|
||||
|
||||
cfg.validate()?;
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
load_runtime_credentials(&mut cfg, storage.clone()).await?;
|
||||
|
||||
let (tui_tx, _tui_rx) = mpsc::unbounded_channel::<TuiRequest>();
|
||||
let tui_controller = Arc::new(TuiController::new(tui_tx));
|
||||
|
||||
let provider = build_provider(&cfg)?;
|
||||
let mut offline_notice: Option<String> = None;
|
||||
let provider = match provider.health_check().await {
|
||||
Ok(_) => provider,
|
||||
Err(err) => {
|
||||
let hint = if matches!(cfg.mcp.mode, McpMode::RemotePreferred | McpMode::RemoteOnly)
|
||||
&& !cfg.effective_mcp_servers().is_empty()
|
||||
{
|
||||
"Ensure the configured MCP server is running and reachable."
|
||||
} else {
|
||||
"Ensure Ollama is running (`ollama serve`) and reachable at the configured base_url."
|
||||
};
|
||||
let notice =
|
||||
format!("Provider health check failed: {err}. {hint} Continuing in offline mode.");
|
||||
eprintln!("{notice}");
|
||||
offline_notice = Some(notice.clone());
|
||||
let fallback_model = cfg
|
||||
.general
|
||||
.default_model
|
||||
.clone()
|
||||
.unwrap_or_else(|| "offline".to_string());
|
||||
Arc::new(OfflineProvider::new(notice, fallback_model)) as Arc<dyn Provider>
|
||||
}
|
||||
};
|
||||
|
||||
let (controller_event_tx, controller_event_rx) = mpsc::unbounded_channel::<ControllerEvent>();
|
||||
let controller = SessionController::new(
|
||||
provider,
|
||||
cfg,
|
||||
storage.clone(),
|
||||
tui_controller,
|
||||
false,
|
||||
Some(controller_event_tx),
|
||||
)
|
||||
.await?;
|
||||
let provider_manager = Arc::new(ProviderManager::default());
|
||||
let mut runtime = RuntimeApp::new(provider_manager);
|
||||
let (mut app, mut session_rx) = ChatApp::new(controller, controller_event_rx).await?;
|
||||
app.initialize_models().await?;
|
||||
if let Some(notice) = offline_notice.clone() {
|
||||
app.set_status_message(¬ice);
|
||||
app.set_system_status(notice);
|
||||
}
|
||||
|
||||
app.set_mode(initial_mode).await;
|
||||
|
||||
enable_raw_mode()?;
|
||||
let mut stdout = io::stdout();
|
||||
execute!(
|
||||
stdout,
|
||||
EnterAlternateScreen,
|
||||
EnableMouseCapture,
|
||||
EnableBracketedPaste
|
||||
)?;
|
||||
let backend = CrosstermBackend::new(stdout);
|
||||
let mut terminal = Terminal::new(backend)?;
|
||||
|
||||
let result = run_app(&mut terminal, &mut runtime, &mut app, &mut session_rx).await;
|
||||
|
||||
config::save_config(&app.config())?;
|
||||
|
||||
disable_raw_mode()?;
|
||||
execute!(
|
||||
terminal.backend_mut(),
|
||||
LeaveAlternateScreen,
|
||||
DisableMouseCapture,
|
||||
DisableBracketedPaste
|
||||
)?;
|
||||
terminal.show_cursor()?;
|
||||
|
||||
if let Err(err) = result {
|
||||
println!("{err:?}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_provider(cfg: &Config) -> Result<Arc<dyn Provider>> {
|
||||
match cfg.mcp.mode {
|
||||
McpMode::RemotePreferred => {
|
||||
let remote_result = if let Some(mcp_server) = cfg.effective_mcp_servers().first() {
|
||||
RemoteMcpClient::new_with_config(mcp_server)
|
||||
} else {
|
||||
RemoteMcpClient::new()
|
||||
};
|
||||
|
||||
match remote_result {
|
||||
Ok(client) => Ok(Arc::new(client) as Arc<dyn Provider>),
|
||||
Err(err) if cfg.mcp.allow_fallback => {
|
||||
log::warn!(
|
||||
"Remote MCP client unavailable ({}); falling back to local provider.",
|
||||
err
|
||||
);
|
||||
build_local_provider(cfg)
|
||||
}
|
||||
Err(err) => Err(anyhow!(err)),
|
||||
}
|
||||
}
|
||||
McpMode::RemoteOnly => {
|
||||
let mcp_server = cfg.effective_mcp_servers().first().ok_or_else(|| {
|
||||
anyhow!("[[mcp_servers]] must be configured when [mcp].mode = \"remote_only\"")
|
||||
})?;
|
||||
let client = RemoteMcpClient::new_with_config(mcp_server)?;
|
||||
Ok(Arc::new(client) as Arc<dyn Provider>)
|
||||
}
|
||||
McpMode::LocalOnly | McpMode::Legacy => build_local_provider(cfg),
|
||||
McpMode::Disabled => Err(anyhow!(
|
||||
"MCP mode 'disabled' is not supported by the owlen TUI"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_local_provider(cfg: &Config) -> Result<Arc<dyn Provider>> {
|
||||
let provider_name = cfg.general.default_provider.clone();
|
||||
let provider_cfg = cfg.provider(&provider_name).ok_or_else(|| {
|
||||
anyhow!(format!(
|
||||
"No provider configuration found for '{provider_name}' in [providers]"
|
||||
))
|
||||
})?;
|
||||
|
||||
match provider_cfg.provider_type.as_str() {
|
||||
"ollama" | "ollama_cloud" => {
|
||||
let provider = OllamaProvider::from_config(provider_cfg, Some(&cfg.general))?;
|
||||
Ok(Arc::new(provider) as Arc<dyn Provider>)
|
||||
}
|
||||
other => Err(anyhow!(format!(
|
||||
"Provider type '{other}' is not supported in legacy/local MCP mode"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
const BASIC_THEME_NAME: &str = "ansi_basic";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum TerminalColorSupport {
|
||||
Full,
|
||||
Limited { term: String },
|
||||
}
|
||||
|
||||
fn detect_terminal_color_support() -> TerminalColorSupport {
|
||||
let term = std::env::var("TERM").unwrap_or_else(|_| "unknown".to_string());
|
||||
let colorterm = std::env::var("COLORTERM").unwrap_or_default();
|
||||
let term_lower = term.to_lowercase();
|
||||
let color_lower = colorterm.to_lowercase();
|
||||
|
||||
let supports_extended = term_lower.contains("256color")
|
||||
|| color_lower.contains("truecolor")
|
||||
|| color_lower.contains("24bit")
|
||||
|| color_lower.contains("fullcolor");
|
||||
|
||||
if supports_extended {
|
||||
TerminalColorSupport::Full
|
||||
} else {
|
||||
TerminalColorSupport::Limited { term }
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_terminal_theme(cfg: &mut Config, support: &TerminalColorSupport) -> Option<String> {
|
||||
match support {
|
||||
TerminalColorSupport::Full => None,
|
||||
TerminalColorSupport::Limited { .. } => {
|
||||
if cfg.ui.theme != BASIC_THEME_NAME {
|
||||
let previous = std::mem::replace(&mut cfg.ui.theme, BASIC_THEME_NAME.to_string());
|
||||
Some(previous)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct OfflineProvider {
|
||||
reason: String,
|
||||
placeholder_model: String,
|
||||
}
|
||||
|
||||
impl OfflineProvider {
|
||||
fn new(reason: String, placeholder_model: String) -> Self {
|
||||
Self {
|
||||
reason,
|
||||
placeholder_model,
|
||||
}
|
||||
}
|
||||
|
||||
fn friendly_response(&self, requested_model: &str) -> ChatResponse {
|
||||
let mut message = String::new();
|
||||
message.push_str("⚠️ Owlen is running in offline mode.\n\n");
|
||||
message.push_str(&self.reason);
|
||||
if !requested_model.is_empty() && requested_model != self.placeholder_model {
|
||||
message.push_str(&format!(
|
||||
"\n\nYou requested model '{}', but no providers are reachable.",
|
||||
requested_model
|
||||
));
|
||||
}
|
||||
message.push_str(
|
||||
"\n\nStart your preferred provider (e.g. `ollama serve`) or switch providers with `:provider` once connectivity is restored.",
|
||||
);
|
||||
|
||||
ChatResponse {
|
||||
message: Message::assistant(message),
|
||||
usage: None,
|
||||
is_streaming: false,
|
||||
is_final: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for OfflineProvider {
|
||||
fn name(&self) -> &str {
|
||||
"offline"
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>, Error> {
|
||||
Ok(vec![ModelInfo {
|
||||
id: self.placeholder_model.clone(),
|
||||
provider: "offline".to_string(),
|
||||
name: format!("Offline (fallback: {})", self.placeholder_model),
|
||||
description: Some("Placeholder model used while no providers are reachable".into()),
|
||||
context_window: None,
|
||||
capabilities: vec![],
|
||||
supports_tools: false,
|
||||
}])
|
||||
}
|
||||
|
||||
async fn send_prompt(&self, request: ChatRequest) -> Result<ChatResponse, Error> {
|
||||
Ok(self.friendly_response(&request.model))
|
||||
}
|
||||
|
||||
async fn stream_prompt(&self, request: ChatRequest) -> Result<ChatStream, Error> {
|
||||
let response = self.friendly_response(&request.model);
|
||||
Ok(Box::pin(stream::iter(vec![Ok(response)])))
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> Result<(), Error> {
|
||||
Err(Error::Provider(anyhow!(
|
||||
"offline provider cannot reach any backing models"
|
||||
)))
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_app(
|
||||
terminal: &mut Terminal<CrosstermBackend<io::Stdout>>,
|
||||
runtime: &mut RuntimeApp,
|
||||
app: &mut ChatApp,
|
||||
session_rx: &mut mpsc::UnboundedReceiver<SessionEvent>,
|
||||
) -> Result<()> {
|
||||
let mut render = |terminal: &mut Terminal<CrosstermBackend<io::Stdout>>,
|
||||
state: &mut ChatApp|
|
||||
-> Result<()> {
|
||||
terminal.draw(|f| ui::render_chat(f, state))?;
|
||||
Ok(())
|
||||
};
|
||||
|
||||
runtime.run(terminal, app, session_rx, &mut render).await?;
|
||||
Ok(())
|
||||
}
|
||||
16
crates/owlen-cli/src/code_main.rs
Normal file
16
crates/owlen-cli/src/code_main.rs
Normal file
@@ -0,0 +1,16 @@
|
||||
//! Owlen CLI entrypoint optimised for code-first workflows.
|
||||
#![allow(dead_code, unused_imports)]
|
||||
|
||||
mod bootstrap;
|
||||
mod commands;
|
||||
mod mcp;
|
||||
|
||||
use anyhow::Result;
|
||||
use owlen_core::config as core_config;
|
||||
use owlen_core::mode::Mode;
|
||||
use owlen_tui::config;
|
||||
|
||||
#[tokio::main(flavor = "multi_thread")]
|
||||
async fn main() -> Result<()> {
|
||||
bootstrap::launch(Mode::Code).await
|
||||
}
|
||||
@@ -6,14 +6,19 @@ use anyhow::{Context, Result, anyhow, bail};
|
||||
use clap::Subcommand;
|
||||
use owlen_core::LlmProvider;
|
||||
use owlen_core::ProviderConfig;
|
||||
use owlen_core::config as core_config;
|
||||
use owlen_core::config::Config;
|
||||
use owlen_core::config::{
|
||||
self as core_config, Config, OLLAMA_CLOUD_API_KEY_ENV, OLLAMA_CLOUD_BASE_URL,
|
||||
OLLAMA_CLOUD_ENDPOINT_KEY, OLLAMA_MODE_KEY,
|
||||
};
|
||||
use owlen_core::credentials::{ApiCredentials, CredentialManager, OLLAMA_CLOUD_CREDENTIAL_ID};
|
||||
use owlen_core::encryption;
|
||||
use owlen_core::providers::OllamaProvider;
|
||||
use owlen_core::storage::StorageManager;
|
||||
use serde_json::Value;
|
||||
|
||||
const DEFAULT_CLOUD_ENDPOINT: &str = "https://ollama.com";
|
||||
const DEFAULT_CLOUD_ENDPOINT: &str = OLLAMA_CLOUD_BASE_URL;
|
||||
const CLOUD_ENDPOINT_KEY: &str = OLLAMA_CLOUD_ENDPOINT_KEY;
|
||||
const CLOUD_PROVIDER_KEY: &str = "ollama_cloud";
|
||||
|
||||
#[derive(Debug, Subcommand)]
|
||||
pub enum CloudCommand {
|
||||
@@ -25,26 +30,29 @@ pub enum CloudCommand {
|
||||
/// Override the cloud endpoint (default: https://ollama.com)
|
||||
#[arg(long)]
|
||||
endpoint: Option<String>,
|
||||
/// Provider name to configure (default: ollama)
|
||||
#[arg(long, default_value = "ollama")]
|
||||
/// Provider name to configure (default: ollama_cloud)
|
||||
#[arg(long, default_value = "ollama_cloud")]
|
||||
provider: String,
|
||||
/// Overwrite the provider base URL with the cloud endpoint
|
||||
#[arg(long)]
|
||||
force_cloud_base_url: bool,
|
||||
},
|
||||
/// Check connectivity to Ollama Cloud
|
||||
Status {
|
||||
/// Provider name to check (default: ollama)
|
||||
#[arg(long, default_value = "ollama")]
|
||||
/// Provider name to check (default: ollama_cloud)
|
||||
#[arg(long, default_value = "ollama_cloud")]
|
||||
provider: String,
|
||||
},
|
||||
/// List available cloud-hosted models
|
||||
Models {
|
||||
/// Provider name to query (default: ollama)
|
||||
#[arg(long, default_value = "ollama")]
|
||||
/// Provider name to query (default: ollama_cloud)
|
||||
#[arg(long, default_value = "ollama_cloud")]
|
||||
provider: String,
|
||||
},
|
||||
/// Remove stored Ollama Cloud credentials
|
||||
Logout {
|
||||
/// Provider name to clear (default: ollama)
|
||||
#[arg(long, default_value = "ollama")]
|
||||
/// Provider name to clear (default: ollama_cloud)
|
||||
#[arg(long, default_value = "ollama_cloud")]
|
||||
provider: String,
|
||||
},
|
||||
}
|
||||
@@ -55,19 +63,30 @@ pub async fn run_cloud_command(command: CloudCommand) -> Result<()> {
|
||||
api_key,
|
||||
endpoint,
|
||||
provider,
|
||||
} => setup(provider, api_key, endpoint).await,
|
||||
force_cloud_base_url,
|
||||
} => setup(provider, api_key, endpoint, force_cloud_base_url).await,
|
||||
CloudCommand::Status { provider } => status(provider).await,
|
||||
CloudCommand::Models { provider } => models(provider).await,
|
||||
CloudCommand::Logout { provider } => logout(provider).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn setup(provider: String, api_key: Option<String>, endpoint: Option<String>) -> Result<()> {
|
||||
async fn setup(
|
||||
provider: String,
|
||||
api_key: Option<String>,
|
||||
endpoint: Option<String>,
|
||||
force_cloud_base_url: bool,
|
||||
) -> Result<()> {
|
||||
let provider = canonical_provider_name(&provider);
|
||||
let mut config = crate::config::try_load_config().unwrap_or_default();
|
||||
let endpoint = endpoint.unwrap_or_else(|| DEFAULT_CLOUD_ENDPOINT.to_string());
|
||||
let endpoint =
|
||||
normalize_endpoint(&endpoint.unwrap_or_else(|| DEFAULT_CLOUD_ENDPOINT.to_string()));
|
||||
|
||||
ensure_provider_entry(&mut config, &provider, &endpoint);
|
||||
let base_changed = {
|
||||
let entry = ensure_provider_entry(&mut config, &provider);
|
||||
entry.enabled = true;
|
||||
configure_cloud_endpoint(entry, &endpoint, force_cloud_base_url)
|
||||
};
|
||||
|
||||
let key = match api_key {
|
||||
Some(value) if !value.trim().is_empty() => value,
|
||||
@@ -95,10 +114,6 @@ async fn setup(provider: String, api_key: Option<String>, endpoint: Option<Strin
|
||||
entry.api_key = Some(key.clone());
|
||||
}
|
||||
|
||||
if let Some(entry) = config.providers.get_mut(&provider) {
|
||||
entry.base_url = Some(endpoint.clone());
|
||||
}
|
||||
|
||||
crate::config::save_config(&config)?;
|
||||
println!("Saved Ollama configuration for provider '{provider}'.");
|
||||
if config.privacy.encrypt_local_data {
|
||||
@@ -106,6 +121,12 @@ async fn setup(provider: String, api_key: Option<String>, endpoint: Option<Strin
|
||||
} else {
|
||||
println!("API key stored in plaintext configuration (encryption disabled).");
|
||||
}
|
||||
if !force_cloud_base_url && !base_changed {
|
||||
println!(
|
||||
"Local base URL preserved; cloud endpoint stored as {}.",
|
||||
CLOUD_ENDPOINT_KEY
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -120,25 +141,32 @@ async fn status(provider: String) -> Result<()> {
|
||||
};
|
||||
|
||||
let api_key = hydrate_api_key(&mut config, manager.as_ref()).await?;
|
||||
ensure_provider_entry(&mut config, &provider, DEFAULT_CLOUD_ENDPOINT);
|
||||
{
|
||||
let entry = ensure_provider_entry(&mut config, &provider);
|
||||
entry.enabled = true;
|
||||
configure_cloud_endpoint(entry, DEFAULT_CLOUD_ENDPOINT, false);
|
||||
}
|
||||
|
||||
let provider_cfg = config
|
||||
.provider(&provider)
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("Provider '{provider}' is not configured"))?;
|
||||
|
||||
let ollama = OllamaProvider::from_config(&provider_cfg, Some(&config.general))
|
||||
let endpoint =
|
||||
resolve_cloud_endpoint(&provider_cfg).unwrap_or_else(|| DEFAULT_CLOUD_ENDPOINT.to_string());
|
||||
let mut runtime_cfg = provider_cfg.clone();
|
||||
runtime_cfg.base_url = Some(endpoint.clone());
|
||||
runtime_cfg.extra.insert(
|
||||
OLLAMA_MODE_KEY.to_string(),
|
||||
Value::String("cloud".to_string()),
|
||||
);
|
||||
|
||||
let ollama = OllamaProvider::from_config(&runtime_cfg, Some(&config.general))
|
||||
.with_context(|| "Failed to construct Ollama provider. Run `owlen cloud setup` first.")?;
|
||||
|
||||
match ollama.health_check().await {
|
||||
Ok(_) => {
|
||||
println!(
|
||||
"✓ Connected to {provider} ({})",
|
||||
provider_cfg
|
||||
.base_url
|
||||
.as_deref()
|
||||
.unwrap_or(DEFAULT_CLOUD_ENDPOINT)
|
||||
);
|
||||
println!("✓ Connected to {provider} ({})", endpoint);
|
||||
if api_key.is_none() && config.privacy.encrypt_local_data {
|
||||
println!(
|
||||
"Warning: No API key stored; connection succeeded via environment variables."
|
||||
@@ -164,13 +192,27 @@ async fn models(provider: String) -> Result<()> {
|
||||
};
|
||||
hydrate_api_key(&mut config, manager.as_ref()).await?;
|
||||
|
||||
ensure_provider_entry(&mut config, &provider, DEFAULT_CLOUD_ENDPOINT);
|
||||
{
|
||||
let entry = ensure_provider_entry(&mut config, &provider);
|
||||
entry.enabled = true;
|
||||
configure_cloud_endpoint(entry, DEFAULT_CLOUD_ENDPOINT, false);
|
||||
}
|
||||
|
||||
let provider_cfg = config
|
||||
.provider(&provider)
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("Provider '{provider}' is not configured"))?;
|
||||
|
||||
let ollama = OllamaProvider::from_config(&provider_cfg, Some(&config.general))
|
||||
let endpoint =
|
||||
resolve_cloud_endpoint(&provider_cfg).unwrap_or_else(|| DEFAULT_CLOUD_ENDPOINT.to_string());
|
||||
let mut runtime_cfg = provider_cfg.clone();
|
||||
runtime_cfg.base_url = Some(endpoint);
|
||||
runtime_cfg.extra.insert(
|
||||
OLLAMA_MODE_KEY.to_string(),
|
||||
Value::String("cloud".to_string()),
|
||||
);
|
||||
|
||||
let ollama = OllamaProvider::from_config(&runtime_cfg, Some(&config.general))
|
||||
.with_context(|| "Failed to construct Ollama provider. Run `owlen cloud setup` first.")?;
|
||||
|
||||
match ollama.list_models().await {
|
||||
@@ -208,8 +250,9 @@ async fn logout(provider: String) -> Result<()> {
|
||||
.await?;
|
||||
}
|
||||
|
||||
if let Some(entry) = provider_entry_mut(&mut config) {
|
||||
if let Some(entry) = config.providers.get_mut(&provider) {
|
||||
entry.api_key = None;
|
||||
entry.enabled = false;
|
||||
}
|
||||
|
||||
crate::config::save_config(&config)?;
|
||||
@@ -217,33 +260,70 @@ async fn logout(provider: String) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ensure_provider_entry(config: &mut Config, provider: &str, endpoint: &str) {
|
||||
if provider == "ollama"
|
||||
&& config.providers.contains_key("ollama-cloud")
|
||||
&& !config.providers.contains_key("ollama")
|
||||
&& let Some(mut legacy) = config.providers.remove("ollama-cloud")
|
||||
{
|
||||
legacy.provider_type = "ollama".to_string();
|
||||
config.providers.insert("ollama".to_string(), legacy);
|
||||
fn ensure_provider_entry<'a>(config: &'a mut Config, provider: &str) -> &'a mut ProviderConfig {
|
||||
core_config::ensure_provider_config_mut(config, provider)
|
||||
}
|
||||
|
||||
fn configure_cloud_endpoint(entry: &mut ProviderConfig, endpoint: &str, force: bool) -> bool {
|
||||
let normalized = normalize_endpoint(endpoint);
|
||||
let previous_base = entry.base_url.clone();
|
||||
entry.extra.insert(
|
||||
CLOUD_ENDPOINT_KEY.to_string(),
|
||||
Value::String(normalized.clone()),
|
||||
);
|
||||
|
||||
if entry.api_key_env.is_none() {
|
||||
entry.api_key_env = Some(OLLAMA_CLOUD_API_KEY_ENV.to_string());
|
||||
}
|
||||
|
||||
core_config::ensure_provider_config(config, provider);
|
||||
if force
|
||||
|| entry
|
||||
.base_url
|
||||
.as_ref()
|
||||
.map(|value| value.trim().is_empty())
|
||||
.unwrap_or(true)
|
||||
{
|
||||
entry.base_url = Some(normalized.clone());
|
||||
}
|
||||
|
||||
if let Some(cfg) = config.providers.get_mut(provider) {
|
||||
if cfg.provider_type != "ollama" {
|
||||
cfg.provider_type = "ollama".to_string();
|
||||
}
|
||||
if cfg.base_url.is_none() {
|
||||
cfg.base_url = Some(endpoint.to_string());
|
||||
}
|
||||
if force {
|
||||
entry.enabled = true;
|
||||
}
|
||||
|
||||
entry.base_url != previous_base
|
||||
}
|
||||
|
||||
fn resolve_cloud_endpoint(cfg: &ProviderConfig) -> Option<String> {
|
||||
if let Some(value) = cfg
|
||||
.extra
|
||||
.get(CLOUD_ENDPOINT_KEY)
|
||||
.and_then(|value| value.as_str())
|
||||
.map(normalize_endpoint)
|
||||
{
|
||||
return Some(value);
|
||||
}
|
||||
|
||||
cfg.base_url
|
||||
.as_ref()
|
||||
.map(|value| value.trim_end_matches('/').to_string())
|
||||
.filter(|value| !value.is_empty())
|
||||
}
|
||||
|
||||
fn normalize_endpoint(endpoint: &str) -> String {
|
||||
let trimmed = endpoint.trim().trim_end_matches('/');
|
||||
if trimmed.is_empty() {
|
||||
DEFAULT_CLOUD_ENDPOINT.to_string()
|
||||
} else {
|
||||
trimmed.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn canonical_provider_name(provider: &str) -> String {
|
||||
let normalized = provider.trim().replace('_', "-").to_ascii_lowercase();
|
||||
let normalized = provider.trim().to_ascii_lowercase().replace('-', "_");
|
||||
match normalized.as_str() {
|
||||
"" => "ollama".to_string(),
|
||||
"ollama-cloud" => "ollama".to_string(),
|
||||
"" => CLOUD_PROVIDER_KEY.to_string(),
|
||||
"ollama" => CLOUD_PROVIDER_KEY.to_string(),
|
||||
"ollama_cloud" => CLOUD_PROVIDER_KEY.to_string(),
|
||||
value => value.to_string(),
|
||||
}
|
||||
}
|
||||
@@ -269,21 +349,6 @@ fn set_env_if_missing(var: &str, value: &str) {
|
||||
}
|
||||
}
|
||||
|
||||
fn provider_entry_mut(config: &mut Config) -> Option<&mut ProviderConfig> {
|
||||
if config.providers.contains_key("ollama") {
|
||||
config.providers.get_mut("ollama")
|
||||
} else {
|
||||
config.providers.get_mut("ollama-cloud")
|
||||
}
|
||||
}
|
||||
|
||||
fn provider_entry(config: &Config) -> Option<&ProviderConfig> {
|
||||
if let Some(entry) = config.providers.get("ollama") {
|
||||
return Some(entry);
|
||||
}
|
||||
config.providers.get("ollama-cloud")
|
||||
}
|
||||
|
||||
fn unlock_credential_manager(
|
||||
config: &Config,
|
||||
storage: Arc<StorageManager>,
|
||||
@@ -315,8 +380,10 @@ fn unlock_vault(path: &Path) -> Result<encryption::VaultHandle> {
|
||||
use std::env;
|
||||
|
||||
if path.exists() {
|
||||
if let Ok(password) = env::var("OWLEN_MASTER_PASSWORD")
|
||||
&& !password.trim().is_empty()
|
||||
if let Some(password) = env::var("OWLEN_MASTER_PASSWORD")
|
||||
.ok()
|
||||
.map(|value| value.trim().to_string())
|
||||
.filter(|password| !password.is_empty())
|
||||
{
|
||||
return encryption::unlock_with_password(path.to_path_buf(), &password)
|
||||
.context("Failed to unlock vault with OWLEN_MASTER_PASSWORD");
|
||||
@@ -356,30 +423,28 @@ async fn hydrate_api_key(
|
||||
config: &mut Config,
|
||||
manager: Option<&Arc<CredentialManager>>,
|
||||
) -> Result<Option<String>> {
|
||||
if let Some(manager) = manager
|
||||
&& let Some(credentials) = manager.get_credentials(OLLAMA_CLOUD_CREDENTIAL_ID).await?
|
||||
{
|
||||
let credentials = match manager {
|
||||
Some(manager) => manager.get_credentials(OLLAMA_CLOUD_CREDENTIAL_ID).await?,
|
||||
None => None,
|
||||
};
|
||||
|
||||
if let Some(credentials) = credentials {
|
||||
let key = credentials.api_key.trim().to_string();
|
||||
if !key.is_empty() {
|
||||
set_env_if_missing("OLLAMA_API_KEY", &key);
|
||||
set_env_if_missing("OLLAMA_CLOUD_API_KEY", &key);
|
||||
}
|
||||
|
||||
if let Some(cfg) = provider_entry_mut(config)
|
||||
&& cfg.base_url.is_none()
|
||||
&& !credentials.endpoint.trim().is_empty()
|
||||
{
|
||||
cfg.base_url = Some(credentials.endpoint);
|
||||
}
|
||||
let cfg = core_config::ensure_provider_config_mut(config, CLOUD_PROVIDER_KEY);
|
||||
configure_cloud_endpoint(cfg, &credentials.endpoint, false);
|
||||
return Ok(Some(key));
|
||||
}
|
||||
|
||||
if let Some(cfg) = provider_entry(config)
|
||||
&& let Some(key) = cfg
|
||||
.api_key
|
||||
.as_ref()
|
||||
.map(|value| value.trim())
|
||||
.filter(|value| !value.is_empty())
|
||||
if let Some(key) = config
|
||||
.provider(CLOUD_PROVIDER_KEY)
|
||||
.and_then(|cfg| cfg.api_key.as_ref())
|
||||
.map(|value| value.trim())
|
||||
.filter(|value| !value.is_empty())
|
||||
{
|
||||
set_env_if_missing("OLLAMA_API_KEY", key);
|
||||
set_env_if_missing("OLLAMA_CLOUD_API_KEY", key);
|
||||
@@ -407,8 +472,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn canonicalises_provider_names() {
|
||||
assert_eq!(canonical_provider_name("OLLAMA_CLOUD"), "ollama");
|
||||
assert_eq!(canonical_provider_name(" ollama-cloud"), "ollama");
|
||||
assert_eq!(canonical_provider_name(""), "ollama");
|
||||
assert_eq!(canonical_provider_name("OLLAMA_CLOUD"), CLOUD_PROVIDER_KEY);
|
||||
assert_eq!(canonical_provider_name(" ollama-cloud"), CLOUD_PROVIDER_KEY);
|
||||
assert_eq!(canonical_provider_name(""), CLOUD_PROVIDER_KEY);
|
||||
}
|
||||
}
|
||||
4
crates/owlen-cli/src/commands/mod.rs
Normal file
4
crates/owlen-cli/src/commands/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
//! Command implementations for the `owlen` CLI.
|
||||
|
||||
pub mod cloud;
|
||||
pub mod providers;
|
||||
651
crates/owlen-cli/src/commands/providers.rs
Normal file
651
crates/owlen-cli/src/commands/providers.rs
Normal file
@@ -0,0 +1,651 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use clap::{Args, Subcommand};
|
||||
use owlen_core::ProviderConfig;
|
||||
use owlen_core::config::{self as core_config, Config};
|
||||
use owlen_core::provider::{
|
||||
AnnotatedModelInfo, ModelProvider, ProviderManager, ProviderStatus, ProviderType,
|
||||
};
|
||||
use owlen_core::storage::StorageManager;
|
||||
use owlen_providers::ollama::{OllamaCloudProvider, OllamaLocalProvider};
|
||||
use owlen_tui::config as tui_config;
|
||||
|
||||
use super::cloud;
|
||||
|
||||
/// CLI subcommands for provider management.
|
||||
#[derive(Debug, Subcommand)]
|
||||
pub enum ProvidersCommand {
|
||||
/// List configured providers and their metadata.
|
||||
List,
|
||||
/// Run health checks against providers.
|
||||
Status {
|
||||
/// Optional provider identifier to check.
|
||||
#[arg(value_name = "PROVIDER")]
|
||||
provider: Option<String>,
|
||||
},
|
||||
/// Enable a provider in the configuration.
|
||||
Enable {
|
||||
/// Provider identifier to enable.
|
||||
provider: String,
|
||||
},
|
||||
/// Disable a provider in the configuration.
|
||||
Disable {
|
||||
/// Provider identifier to disable.
|
||||
provider: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// Arguments for the `owlen models` command.
|
||||
#[derive(Debug, Default, Args)]
|
||||
pub struct ModelsArgs {
|
||||
/// Restrict output to a specific provider.
|
||||
#[arg(long)]
|
||||
pub provider: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn run_providers_command(command: ProvidersCommand) -> Result<()> {
|
||||
match command {
|
||||
ProvidersCommand::List => list_providers(),
|
||||
ProvidersCommand::Status { provider } => status_providers(provider.as_deref()).await,
|
||||
ProvidersCommand::Enable { provider } => toggle_provider(&provider, true),
|
||||
ProvidersCommand::Disable { provider } => toggle_provider(&provider, false),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_models_command(args: ModelsArgs) -> Result<()> {
|
||||
list_models(args.provider.as_deref()).await
|
||||
}
|
||||
|
||||
fn list_providers() -> Result<()> {
|
||||
let config = tui_config::try_load_config().unwrap_or_default();
|
||||
let default_provider = canonical_provider_id(&config.general.default_provider);
|
||||
|
||||
let mut rows = Vec::new();
|
||||
for (id, cfg) in &config.providers {
|
||||
let type_label = describe_provider_type(id, cfg);
|
||||
let auth_label = describe_auth(cfg, requires_auth(id, cfg));
|
||||
let enabled = if cfg.enabled { "yes" } else { "no" };
|
||||
let default = if id == &default_provider { "*" } else { "" };
|
||||
let base = cfg
|
||||
.base_url
|
||||
.as_ref()
|
||||
.map(|value| value.trim().to_string())
|
||||
.unwrap_or_else(|| "-".to_string());
|
||||
|
||||
rows.push(ProviderListRow {
|
||||
id: id.to_string(),
|
||||
type_label,
|
||||
enabled: enabled.to_string(),
|
||||
default: default.to_string(),
|
||||
auth: auth_label,
|
||||
base_url: base,
|
||||
});
|
||||
}
|
||||
|
||||
rows.sort_by(|a, b| a.id.cmp(&b.id));
|
||||
|
||||
let id_width = rows
|
||||
.iter()
|
||||
.map(|row| row.id.len())
|
||||
.max()
|
||||
.unwrap_or(8)
|
||||
.max("Provider".len());
|
||||
let enabled_width = rows
|
||||
.iter()
|
||||
.map(|row| row.enabled.len())
|
||||
.max()
|
||||
.unwrap_or(7)
|
||||
.max("Enabled".len());
|
||||
let default_width = rows
|
||||
.iter()
|
||||
.map(|row| row.default.len())
|
||||
.max()
|
||||
.unwrap_or(7)
|
||||
.max("Default".len());
|
||||
let type_width = rows
|
||||
.iter()
|
||||
.map(|row| row.type_label.len())
|
||||
.max()
|
||||
.unwrap_or(4)
|
||||
.max("Type".len());
|
||||
let auth_width = rows
|
||||
.iter()
|
||||
.map(|row| row.auth.len())
|
||||
.max()
|
||||
.unwrap_or(4)
|
||||
.max("Auth".len());
|
||||
|
||||
println!(
|
||||
"{:<id_width$} {:<enabled_width$} {:<default_width$} {:<type_width$} {:<auth_width$} Base URL",
|
||||
"Provider",
|
||||
"Enabled",
|
||||
"Default",
|
||||
"Type",
|
||||
"Auth",
|
||||
id_width = id_width,
|
||||
enabled_width = enabled_width,
|
||||
default_width = default_width,
|
||||
type_width = type_width,
|
||||
auth_width = auth_width,
|
||||
);
|
||||
|
||||
for row in rows {
|
||||
println!(
|
||||
"{:<id_width$} {:<enabled_width$} {:<default_width$} {:<type_width$} {:<auth_width$} {}",
|
||||
row.id,
|
||||
row.enabled,
|
||||
row.default,
|
||||
row.type_label,
|
||||
row.auth,
|
||||
row.base_url,
|
||||
id_width = id_width,
|
||||
enabled_width = enabled_width,
|
||||
default_width = default_width,
|
||||
type_width = type_width,
|
||||
auth_width = auth_width,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn status_providers(filter: Option<&str>) -> Result<()> {
|
||||
let mut config = tui_config::try_load_config().unwrap_or_default();
|
||||
let filter = filter.map(canonical_provider_id);
|
||||
verify_provider_filter(&config, filter.as_deref())?;
|
||||
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
cloud::load_runtime_credentials(&mut config, storage.clone()).await?;
|
||||
|
||||
let manager = ProviderManager::new(&config);
|
||||
let records = register_enabled_providers(&manager, &config, filter.as_deref()).await?;
|
||||
let health = manager.refresh_health().await;
|
||||
|
||||
let mut rows = Vec::new();
|
||||
for record in records {
|
||||
let status = health.get(&record.id).copied();
|
||||
rows.push(ProviderStatusRow::from_record(record, status));
|
||||
}
|
||||
|
||||
rows.sort_by(|a, b| a.id.cmp(&b.id));
|
||||
print_status_rows(&rows);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn list_models(filter: Option<&str>) -> Result<()> {
|
||||
let mut config = tui_config::try_load_config().unwrap_or_default();
|
||||
let filter = filter.map(canonical_provider_id);
|
||||
verify_provider_filter(&config, filter.as_deref())?;
|
||||
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
cloud::load_runtime_credentials(&mut config, storage.clone()).await?;
|
||||
|
||||
let manager = ProviderManager::new(&config);
|
||||
let records = register_enabled_providers(&manager, &config, filter.as_deref()).await?;
|
||||
let models = manager
|
||||
.list_all_models()
|
||||
.await
|
||||
.map_err(|err| anyhow!(err))?;
|
||||
let statuses = manager.provider_statuses().await;
|
||||
|
||||
print_models(records, models, statuses);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn verify_provider_filter(config: &Config, filter: Option<&str>) -> Result<()> {
|
||||
if let Some(filter) = filter
|
||||
&& !config.providers.contains_key(filter)
|
||||
{
|
||||
return Err(anyhow!(
|
||||
"Provider '{}' is not defined in configuration.",
|
||||
filter
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn toggle_provider(provider: &str, enable: bool) -> Result<()> {
|
||||
let mut config = tui_config::try_load_config().unwrap_or_default();
|
||||
let canonical = canonical_provider_id(provider);
|
||||
if canonical.is_empty() {
|
||||
return Err(anyhow!("Provider name cannot be empty."));
|
||||
}
|
||||
|
||||
let previous_default = config.general.default_provider.clone();
|
||||
let previous_fallback_enabled = config.providers.get("ollama_local").map(|cfg| cfg.enabled);
|
||||
|
||||
let previous_enabled;
|
||||
{
|
||||
let entry = core_config::ensure_provider_config_mut(&mut config, &canonical);
|
||||
previous_enabled = entry.enabled;
|
||||
if previous_enabled == enable {
|
||||
println!(
|
||||
"Provider '{}' is already {}.",
|
||||
canonical,
|
||||
if enable { "enabled" } else { "disabled" }
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
entry.enabled = enable;
|
||||
}
|
||||
|
||||
if !enable && config.general.default_provider == canonical {
|
||||
if let Some(candidate) = choose_fallback_provider(&config, &canonical) {
|
||||
config.general.default_provider = candidate.clone();
|
||||
println!(
|
||||
"Default provider set to '{}' because '{}' was disabled.",
|
||||
candidate, canonical
|
||||
);
|
||||
} else {
|
||||
let entry = core_config::ensure_provider_config_mut(&mut config, "ollama_local");
|
||||
entry.enabled = true;
|
||||
config.general.default_provider = "ollama_local".to_string();
|
||||
println!(
|
||||
"Enabled 'ollama_local' and made it default because no other providers are active."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(err) = config.validate() {
|
||||
{
|
||||
let entry = core_config::ensure_provider_config_mut(&mut config, &canonical);
|
||||
entry.enabled = previous_enabled;
|
||||
}
|
||||
config.general.default_provider = previous_default;
|
||||
if let Some(enabled) = previous_fallback_enabled
|
||||
&& let Some(entry) = config.providers.get_mut("ollama_local")
|
||||
{
|
||||
entry.enabled = enabled;
|
||||
}
|
||||
return Err(anyhow!(err));
|
||||
}
|
||||
|
||||
tui_config::save_config(&config).map_err(|err| anyhow!(err))?;
|
||||
|
||||
println!(
|
||||
"{} provider '{}'.",
|
||||
if enable { "Enabled" } else { "Disabled" },
|
||||
canonical
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn choose_fallback_provider(config: &Config, exclude: &str) -> Option<String> {
|
||||
if exclude != "ollama_local"
|
||||
&& let Some(cfg) = config.providers.get("ollama_local")
|
||||
&& cfg.enabled
|
||||
{
|
||||
return Some("ollama_local".to_string());
|
||||
}
|
||||
|
||||
let mut candidates: Vec<String> = config
|
||||
.providers
|
||||
.iter()
|
||||
.filter(|(id, cfg)| cfg.enabled && id.as_str() != exclude)
|
||||
.map(|(id, _)| id.clone())
|
||||
.collect();
|
||||
candidates.sort();
|
||||
candidates.into_iter().next()
|
||||
}
|
||||
|
||||
async fn register_enabled_providers(
|
||||
manager: &ProviderManager,
|
||||
config: &Config,
|
||||
filter: Option<&str>,
|
||||
) -> Result<Vec<ProviderRecord>> {
|
||||
let default_provider = canonical_provider_id(&config.general.default_provider);
|
||||
let mut records = Vec::new();
|
||||
|
||||
for (id, cfg) in &config.providers {
|
||||
if let Some(filter) = filter
|
||||
&& id != filter
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut record = ProviderRecord::from_config(id, cfg, id == &default_provider);
|
||||
if !cfg.enabled {
|
||||
records.push(record);
|
||||
continue;
|
||||
}
|
||||
|
||||
match instantiate_provider(id, cfg) {
|
||||
Ok(provider) => {
|
||||
let metadata = provider.metadata().clone();
|
||||
record.provider_type_label = provider_type_label(metadata.provider_type);
|
||||
record.requires_auth = metadata.requires_auth;
|
||||
record.metadata = Some(metadata);
|
||||
manager.register_provider(provider).await;
|
||||
}
|
||||
Err(err) => {
|
||||
record.registration_error = Some(err.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
records.push(record);
|
||||
}
|
||||
|
||||
records.sort_by(|a, b| a.id.cmp(&b.id));
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
fn instantiate_provider(id: &str, cfg: &ProviderConfig) -> Result<Arc<dyn ModelProvider>> {
|
||||
let kind = cfg.provider_type.trim().to_ascii_lowercase();
|
||||
if kind == "ollama" || id == "ollama_local" {
|
||||
let provider = OllamaLocalProvider::new(cfg.base_url.clone(), None, None)
|
||||
.map_err(|err| anyhow!(err))?;
|
||||
Ok(Arc::new(provider))
|
||||
} else if kind == "ollama_cloud" || id == "ollama_cloud" {
|
||||
let provider = OllamaCloudProvider::new(cfg.base_url.clone(), cfg.api_key.clone(), None)
|
||||
.map_err(|err| anyhow!(err))?;
|
||||
Ok(Arc::new(provider))
|
||||
} else {
|
||||
Err(anyhow!(
|
||||
"Provider '{}' uses unsupported type '{}'.",
|
||||
id,
|
||||
if kind.is_empty() {
|
||||
"unknown"
|
||||
} else {
|
||||
kind.as_str()
|
||||
}
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn describe_provider_type(id: &str, cfg: &ProviderConfig) -> String {
|
||||
if cfg.provider_type.trim().eq_ignore_ascii_case("ollama") || id.ends_with("_local") {
|
||||
"Local".to_string()
|
||||
} else if cfg
|
||||
.provider_type
|
||||
.trim()
|
||||
.eq_ignore_ascii_case("ollama_cloud")
|
||||
|| id.contains("cloud")
|
||||
{
|
||||
"Cloud".to_string()
|
||||
} else {
|
||||
"Custom".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn requires_auth(id: &str, cfg: &ProviderConfig) -> bool {
|
||||
cfg.api_key.is_some()
|
||||
|| cfg.api_key_env.is_some()
|
||||
|| matches!(id, "ollama_cloud" | "openai" | "anthropic")
|
||||
}
|
||||
|
||||
fn describe_auth(cfg: &ProviderConfig, required: bool) -> String {
|
||||
if let Some(env) = cfg
|
||||
.api_key_env
|
||||
.as_ref()
|
||||
.map(|value| value.trim())
|
||||
.filter(|value| !value.is_empty())
|
||||
{
|
||||
format!("env:{env}")
|
||||
} else if cfg
|
||||
.api_key
|
||||
.as_ref()
|
||||
.map(|value| !value.trim().is_empty())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
"config".to_string()
|
||||
} else if required {
|
||||
"required".to_string()
|
||||
} else {
|
||||
"-".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn canonical_provider_id(raw: &str) -> String {
|
||||
let trimmed = raw.trim().to_ascii_lowercase();
|
||||
if trimmed.is_empty() {
|
||||
return trimmed;
|
||||
}
|
||||
|
||||
match trimmed.as_str() {
|
||||
"ollama" | "ollama-local" => "ollama_local".to_string(),
|
||||
"ollama_cloud" | "ollama-cloud" => "ollama_cloud".to_string(),
|
||||
other => other.replace('-', "_"),
|
||||
}
|
||||
}
|
||||
|
||||
fn provider_type_label(provider_type: ProviderType) -> String {
|
||||
match provider_type {
|
||||
ProviderType::Local => "Local".to_string(),
|
||||
ProviderType::Cloud => "Cloud".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn provider_status_strings(status: ProviderStatus) -> (&'static str, &'static str) {
|
||||
match status {
|
||||
ProviderStatus::Available => ("OK", "available"),
|
||||
ProviderStatus::Unavailable => ("ERR", "unavailable"),
|
||||
ProviderStatus::RequiresSetup => ("SETUP", "requires setup"),
|
||||
}
|
||||
}
|
||||
|
||||
fn print_status_rows(rows: &[ProviderStatusRow]) {
|
||||
let id_width = rows
|
||||
.iter()
|
||||
.map(|row| row.id.len())
|
||||
.max()
|
||||
.unwrap_or(8)
|
||||
.max("Provider".len());
|
||||
let type_width = rows
|
||||
.iter()
|
||||
.map(|row| row.provider_type.len())
|
||||
.max()
|
||||
.unwrap_or(4)
|
||||
.max("Type".len());
|
||||
let status_width = rows
|
||||
.iter()
|
||||
.map(|row| row.indicator.len() + 1 + row.status_label.len())
|
||||
.max()
|
||||
.unwrap_or(6)
|
||||
.max("State".len());
|
||||
|
||||
println!(
|
||||
"{:<id_width$} {:<4} {:<type_width$} {:<status_width$} Details",
|
||||
"Provider",
|
||||
"Def",
|
||||
"Type",
|
||||
"State",
|
||||
id_width = id_width,
|
||||
type_width = type_width,
|
||||
status_width = status_width,
|
||||
);
|
||||
|
||||
for row in rows {
|
||||
let def = if row.default_provider { "*" } else { "-" };
|
||||
let details = row.detail.as_deref().unwrap_or("-");
|
||||
println!(
|
||||
"{:<id_width$} {:<4} {:<type_width$} {:<status_width$} {}",
|
||||
row.id,
|
||||
def,
|
||||
row.provider_type,
|
||||
format!("{} {}", row.indicator, row.status_label),
|
||||
details,
|
||||
id_width = id_width,
|
||||
type_width = type_width,
|
||||
status_width = status_width,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn print_models(
|
||||
records: Vec<ProviderRecord>,
|
||||
models: Vec<AnnotatedModelInfo>,
|
||||
statuses: HashMap<String, ProviderStatus>,
|
||||
) {
|
||||
let mut grouped: HashMap<String, Vec<AnnotatedModelInfo>> = HashMap::new();
|
||||
for info in models {
|
||||
grouped
|
||||
.entry(info.provider_id.clone())
|
||||
.or_default()
|
||||
.push(info);
|
||||
}
|
||||
|
||||
for record in records {
|
||||
let status = statuses.get(&record.id).copied().or_else(|| {
|
||||
if record.metadata.is_some() && record.registration_error.is_none() && record.enabled {
|
||||
Some(ProviderStatus::Unavailable)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
let (indicator, label, status_value) = if !record.enabled {
|
||||
("-", "disabled", None)
|
||||
} else if record.registration_error.is_some() {
|
||||
("ERR", "error", None)
|
||||
} else if let Some(status) = status {
|
||||
let (indicator, label) = provider_status_strings(status);
|
||||
(indicator, label, Some(status))
|
||||
} else {
|
||||
("?", "unknown", None)
|
||||
};
|
||||
|
||||
let title = if record.default_provider {
|
||||
format!("{} (default)", record.id)
|
||||
} else {
|
||||
record.id.clone()
|
||||
};
|
||||
println!(
|
||||
"{} {} [{}] {}",
|
||||
indicator, title, record.provider_type_label, label
|
||||
);
|
||||
|
||||
if let Some(err) = &record.registration_error {
|
||||
println!(" error: {}", err);
|
||||
println!();
|
||||
continue;
|
||||
}
|
||||
|
||||
if !record.enabled {
|
||||
println!(" provider disabled");
|
||||
println!();
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(entries) = grouped.get(&record.id) {
|
||||
let mut entries = entries.clone();
|
||||
entries.sort_by(|a, b| a.model.name.cmp(&b.model.name));
|
||||
if entries.is_empty() {
|
||||
println!(" (no models reported)");
|
||||
} else {
|
||||
for entry in entries {
|
||||
let mut line = format!(" - {}", entry.model.name);
|
||||
if let Some(description) = &entry.model.description
|
||||
&& !description.trim().is_empty()
|
||||
{
|
||||
line.push_str(&format!(" — {}", description.trim()));
|
||||
}
|
||||
println!("{}", line);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
println!(" (no models reported)");
|
||||
}
|
||||
|
||||
if let Some(ProviderStatus::RequiresSetup) = status_value
|
||||
&& record.requires_auth
|
||||
{
|
||||
println!(" configure provider credentials or API key");
|
||||
}
|
||||
println!();
|
||||
}
|
||||
}
|
||||
|
||||
struct ProviderListRow {
|
||||
id: String,
|
||||
type_label: String,
|
||||
enabled: String,
|
||||
default: String,
|
||||
auth: String,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
struct ProviderRecord {
|
||||
id: String,
|
||||
enabled: bool,
|
||||
default_provider: bool,
|
||||
provider_type_label: String,
|
||||
requires_auth: bool,
|
||||
registration_error: Option<String>,
|
||||
metadata: Option<owlen_core::provider::ProviderMetadata>,
|
||||
}
|
||||
|
||||
impl ProviderRecord {
|
||||
fn from_config(id: &str, cfg: &ProviderConfig, default_provider: bool) -> Self {
|
||||
Self {
|
||||
id: id.to_string(),
|
||||
enabled: cfg.enabled,
|
||||
default_provider,
|
||||
provider_type_label: describe_provider_type(id, cfg),
|
||||
requires_auth: requires_auth(id, cfg),
|
||||
registration_error: None,
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ProviderStatusRow {
|
||||
id: String,
|
||||
provider_type: String,
|
||||
default_provider: bool,
|
||||
indicator: String,
|
||||
status_label: String,
|
||||
detail: Option<String>,
|
||||
}
|
||||
|
||||
impl ProviderStatusRow {
|
||||
fn from_record(record: ProviderRecord, status: Option<ProviderStatus>) -> Self {
|
||||
if !record.enabled {
|
||||
return Self {
|
||||
id: record.id,
|
||||
provider_type: record.provider_type_label,
|
||||
default_provider: record.default_provider,
|
||||
indicator: "-".to_string(),
|
||||
status_label: "disabled".to_string(),
|
||||
detail: None,
|
||||
};
|
||||
}
|
||||
|
||||
if let Some(err) = record.registration_error {
|
||||
return Self {
|
||||
id: record.id,
|
||||
provider_type: record.provider_type_label,
|
||||
default_provider: record.default_provider,
|
||||
indicator: "ERR".to_string(),
|
||||
status_label: "error".to_string(),
|
||||
detail: Some(err),
|
||||
};
|
||||
}
|
||||
|
||||
if let Some(status) = status {
|
||||
let (indicator, label) = provider_status_strings(status);
|
||||
return Self {
|
||||
id: record.id,
|
||||
provider_type: record.provider_type_label,
|
||||
default_provider: record.default_provider,
|
||||
indicator: indicator.to_string(),
|
||||
status_label: label.to_string(),
|
||||
detail: if matches!(status, ProviderStatus::RequiresSetup) && record.requires_auth {
|
||||
Some("credentials required".to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
Self {
|
||||
id: record.id,
|
||||
provider_type: record.provider_type_label,
|
||||
default_provider: record.default_provider,
|
||||
indicator: "?".to_string(),
|
||||
status_label: "unknown".to_string(),
|
||||
detail: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,38 +1,22 @@
|
||||
#![allow(clippy::collapsible_if)] // TODO: Remove once Rust 2024 let-chains are available
|
||||
|
||||
//! OWLEN CLI - Chat TUI client
|
||||
|
||||
mod cloud;
|
||||
mod bootstrap;
|
||||
mod commands;
|
||||
mod mcp;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use async_trait::async_trait;
|
||||
use anyhow::Result;
|
||||
use clap::{Parser, Subcommand};
|
||||
use cloud::{CloudCommand, load_runtime_credentials, set_env_var};
|
||||
use commands::{
|
||||
cloud::{CloudCommand, run_cloud_command},
|
||||
providers::{ModelsArgs, ProvidersCommand, run_models_command, run_providers_command},
|
||||
};
|
||||
use mcp::{McpCommand, run_mcp_command};
|
||||
use owlen_core::config as core_config;
|
||||
use owlen_core::{
|
||||
ChatStream, Error, Provider,
|
||||
config::{Config, McpMode},
|
||||
mcp::remote_client::RemoteMcpClient,
|
||||
mode::Mode,
|
||||
providers::OllamaProvider,
|
||||
session::SessionController,
|
||||
storage::StorageManager,
|
||||
types::{ChatRequest, ChatResponse, Message, ModelInfo},
|
||||
};
|
||||
use owlen_tui::tui_controller::{TuiController, TuiRequest};
|
||||
use owlen_tui::{AppState, ChatApp, Event, EventHandler, SessionEvent, config, ui};
|
||||
use std::any::Any;
|
||||
use std::borrow::Cow;
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crossterm::{
|
||||
event::{DisableBracketedPaste, DisableMouseCapture, EnableBracketedPaste, EnableMouseCapture},
|
||||
execute,
|
||||
terminal::{EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode},
|
||||
};
|
||||
use futures::stream;
|
||||
use ratatui::{Terminal, prelude::CrosstermBackend};
|
||||
use owlen_core::config::McpMode;
|
||||
use owlen_core::mode::Mode;
|
||||
use owlen_tui::config;
|
||||
|
||||
/// Owlen - Terminal UI for LLM chat
|
||||
#[derive(Parser, Debug)]
|
||||
@@ -54,6 +38,14 @@ enum OwlenCommand {
|
||||
/// Manage Ollama Cloud credentials
|
||||
#[command(subcommand)]
|
||||
Cloud(CloudCommand),
|
||||
/// Manage model providers
|
||||
#[command(subcommand)]
|
||||
Providers(ProvidersCommand),
|
||||
/// List models exposed by configured providers
|
||||
Models(ModelsArgs),
|
||||
/// Manage MCP server registrations
|
||||
#[command(subcommand)]
|
||||
Mcp(McpCommand),
|
||||
/// Show manual steps for updating Owlen to the latest revision
|
||||
Upgrade,
|
||||
}
|
||||
@@ -66,70 +58,13 @@ enum ConfigCommand {
|
||||
Path,
|
||||
}
|
||||
|
||||
fn build_provider(cfg: &Config) -> anyhow::Result<Arc<dyn Provider>> {
|
||||
match cfg.mcp.mode {
|
||||
McpMode::RemotePreferred => {
|
||||
let remote_result = if let Some(mcp_server) = cfg.mcp_servers.first() {
|
||||
RemoteMcpClient::new_with_config(mcp_server)
|
||||
} else {
|
||||
RemoteMcpClient::new()
|
||||
};
|
||||
|
||||
match remote_result {
|
||||
Ok(client) => {
|
||||
let provider: Arc<dyn Provider> = Arc::new(client);
|
||||
Ok(provider)
|
||||
}
|
||||
Err(err) if cfg.mcp.allow_fallback => {
|
||||
log::warn!(
|
||||
"Remote MCP client unavailable ({}); falling back to local provider.",
|
||||
err
|
||||
);
|
||||
build_local_provider(cfg)
|
||||
}
|
||||
Err(err) => Err(anyhow::Error::from(err)),
|
||||
}
|
||||
}
|
||||
McpMode::RemoteOnly => {
|
||||
let mcp_server = cfg.mcp_servers.first().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"[[mcp_servers]] must be configured when [mcp].mode = \"remote_only\""
|
||||
)
|
||||
})?;
|
||||
let client = RemoteMcpClient::new_with_config(mcp_server)?;
|
||||
let provider: Arc<dyn Provider> = Arc::new(client);
|
||||
Ok(provider)
|
||||
}
|
||||
McpMode::LocalOnly | McpMode::Legacy => build_local_provider(cfg),
|
||||
McpMode::Disabled => Err(anyhow::anyhow!(
|
||||
"MCP mode 'disabled' is not supported by the owlen TUI"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_local_provider(cfg: &Config) -> anyhow::Result<Arc<dyn Provider>> {
|
||||
let provider_name = cfg.general.default_provider.clone();
|
||||
let provider_cfg = cfg.provider(&provider_name).ok_or_else(|| {
|
||||
anyhow::anyhow!(format!(
|
||||
"No provider configuration found for '{provider_name}' in [providers]"
|
||||
))
|
||||
})?;
|
||||
|
||||
match provider_cfg.provider_type.as_str() {
|
||||
"ollama" | "ollama-cloud" => {
|
||||
let provider = OllamaProvider::from_config(provider_cfg, Some(&cfg.general))?;
|
||||
Ok(Arc::new(provider) as Arc<dyn Provider>)
|
||||
}
|
||||
other => Err(anyhow::anyhow!(format!(
|
||||
"Provider type '{other}' is not supported in legacy/local MCP mode"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_command(command: OwlenCommand) -> Result<()> {
|
||||
match command {
|
||||
OwlenCommand::Config(config_cmd) => run_config_command(config_cmd),
|
||||
OwlenCommand::Cloud(cloud_cmd) => cloud::run_cloud_command(cloud_cmd).await,
|
||||
OwlenCommand::Cloud(cloud_cmd) => run_cloud_command(cloud_cmd).await,
|
||||
OwlenCommand::Providers(provider_cmd) => run_providers_command(provider_cmd).await,
|
||||
OwlenCommand::Models(args) => run_models_command(args).await,
|
||||
OwlenCommand::Mcp(mcp_cmd) => run_mcp_command(mcp_cmd),
|
||||
OwlenCommand::Upgrade => {
|
||||
println!(
|
||||
"To update Owlen from source:\n git pull\n cargo install --path crates/owlen-cli --force"
|
||||
@@ -157,46 +92,85 @@ fn run_config_doctor() -> Result<()> {
|
||||
let config_path = core_config::default_config_path();
|
||||
let existed = config_path.exists();
|
||||
let mut config = config::try_load_config().unwrap_or_default();
|
||||
let _ = config.refresh_mcp_servers(None);
|
||||
let mut changes = Vec::new();
|
||||
|
||||
if !existed {
|
||||
changes.push("created configuration file from defaults".to_string());
|
||||
}
|
||||
|
||||
if !config
|
||||
.providers
|
||||
.contains_key(&config.general.default_provider)
|
||||
{
|
||||
config.general.default_provider = "ollama".to_string();
|
||||
changes.push("default provider missing; reset to 'ollama'".to_string());
|
||||
if config.provider(&config.general.default_provider).is_none() {
|
||||
config.general.default_provider = "ollama_local".to_string();
|
||||
changes.push("default provider missing; reset to 'ollama_local'".to_string());
|
||||
}
|
||||
|
||||
if let Some(mut legacy) = config.providers.remove("ollama-cloud") {
|
||||
legacy.provider_type = "ollama".to_string();
|
||||
use std::collections::hash_map::Entry;
|
||||
match config.providers.entry("ollama".to_string()) {
|
||||
Entry::Occupied(mut existing) => {
|
||||
let entry = existing.get_mut();
|
||||
if entry.api_key.is_none() {
|
||||
entry.api_key = legacy.api_key.take();
|
||||
for key in ["ollama_local", "ollama_cloud", "openai", "anthropic"] {
|
||||
if !config.providers.contains_key(key) {
|
||||
core_config::ensure_provider_config_mut(&mut config, key);
|
||||
changes.push(format!("added default configuration for provider '{key}'"));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(entry) = config.providers.get_mut("ollama_local") {
|
||||
if entry.provider_type.trim().is_empty() || entry.provider_type != "ollama" {
|
||||
entry.provider_type = "ollama".to_string();
|
||||
changes.push("normalised providers.ollama_local.provider_type to 'ollama'".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
let mut ensure_default_enabled = true;
|
||||
|
||||
if !config.providers.values().any(|cfg| cfg.enabled) {
|
||||
let entry = core_config::ensure_provider_config_mut(&mut config, "ollama_local");
|
||||
if !entry.enabled {
|
||||
entry.enabled = true;
|
||||
changes.push("no providers were enabled; enabled 'ollama_local'".to_string());
|
||||
}
|
||||
if config.general.default_provider != "ollama_local" {
|
||||
config.general.default_provider = "ollama_local".to_string();
|
||||
changes.push(
|
||||
"default provider reset to 'ollama_local' because no providers were enabled"
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
ensure_default_enabled = false;
|
||||
}
|
||||
|
||||
if ensure_default_enabled {
|
||||
let default_id = config.general.default_provider.clone();
|
||||
if let Some(default_cfg) = config.providers.get(&default_id) {
|
||||
if !default_cfg.enabled {
|
||||
if let Some(new_default) = config
|
||||
.providers
|
||||
.iter()
|
||||
.filter(|(id, cfg)| cfg.enabled && *id != &default_id)
|
||||
.map(|(id, _)| id.clone())
|
||||
.min()
|
||||
{
|
||||
config.general.default_provider = new_default.clone();
|
||||
changes.push(format!(
|
||||
"default provider '{default_id}' was disabled; switched default to '{new_default}'"
|
||||
));
|
||||
} else {
|
||||
let entry =
|
||||
core_config::ensure_provider_config_mut(&mut config, "ollama_local");
|
||||
if !entry.enabled {
|
||||
entry.enabled = true;
|
||||
changes.push(
|
||||
"enabled 'ollama_local' because default provider was disabled"
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
if config.general.default_provider != "ollama_local" {
|
||||
config.general.default_provider = "ollama_local".to_string();
|
||||
changes.push(
|
||||
"default provider reset to 'ollama_local' because previous default was disabled"
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
if entry.base_url.is_none() && legacy.base_url.is_some() {
|
||||
entry.base_url = legacy.base_url.take();
|
||||
}
|
||||
entry.extra.extend(legacy.extra);
|
||||
}
|
||||
Entry::Vacant(slot) => {
|
||||
slot.insert(legacy);
|
||||
}
|
||||
}
|
||||
changes.push(
|
||||
"migrated legacy 'ollama-cloud' provider into unified 'ollama' entry".to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
if !config.providers.contains_key("ollama") {
|
||||
core_config::ensure_provider_config(&mut config, "ollama");
|
||||
changes.push("added default ollama provider configuration".to_string());
|
||||
}
|
||||
|
||||
match config.mcp.mode {
|
||||
@@ -205,7 +179,7 @@ fn run_config_doctor() -> Result<()> {
|
||||
config.mcp.warn_on_legacy = true;
|
||||
changes.push("converted [mcp].mode = 'legacy' to 'local_only'".to_string());
|
||||
}
|
||||
McpMode::RemoteOnly if config.mcp_servers.is_empty() => {
|
||||
McpMode::RemoteOnly if config.effective_mcp_servers().is_empty() => {
|
||||
config.mcp.mode = McpMode::RemotePreferred;
|
||||
config.mcp.allow_fallback = true;
|
||||
changes.push(
|
||||
@@ -213,7 +187,9 @@ fn run_config_doctor() -> Result<()> {
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
McpMode::RemotePreferred if !config.mcp.allow_fallback && config.mcp_servers.is_empty() => {
|
||||
McpMode::RemotePreferred
|
||||
if !config.mcp.allow_fallback && config.effective_mcp_servers().is_empty() =>
|
||||
{
|
||||
config.mcp.allow_fallback = true;
|
||||
changes.push(
|
||||
"enabled [mcp].allow_fallback because no remote servers are configured".to_string(),
|
||||
@@ -240,120 +216,6 @@ fn run_config_doctor() -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
const BASIC_THEME_NAME: &str = "ansi_basic";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum TerminalColorSupport {
|
||||
Full,
|
||||
Limited { term: String },
|
||||
}
|
||||
|
||||
fn detect_terminal_color_support() -> TerminalColorSupport {
|
||||
let term = std::env::var("TERM").unwrap_or_else(|_| "unknown".to_string());
|
||||
let colorterm = std::env::var("COLORTERM").unwrap_or_default();
|
||||
let term_lower = term.to_lowercase();
|
||||
let color_lower = colorterm.to_lowercase();
|
||||
|
||||
let supports_extended = term_lower.contains("256color")
|
||||
|| color_lower.contains("truecolor")
|
||||
|| color_lower.contains("24bit")
|
||||
|| color_lower.contains("fullcolor");
|
||||
|
||||
if supports_extended {
|
||||
TerminalColorSupport::Full
|
||||
} else {
|
||||
TerminalColorSupport::Limited { term }
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_terminal_theme(cfg: &mut Config, support: &TerminalColorSupport) -> Option<String> {
|
||||
match support {
|
||||
TerminalColorSupport::Full => None,
|
||||
TerminalColorSupport::Limited { .. } => {
|
||||
if cfg.ui.theme != BASIC_THEME_NAME {
|
||||
let previous = std::mem::replace(&mut cfg.ui.theme, BASIC_THEME_NAME.to_string());
|
||||
Some(previous)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct OfflineProvider {
|
||||
reason: String,
|
||||
placeholder_model: String,
|
||||
}
|
||||
|
||||
impl OfflineProvider {
|
||||
fn new(reason: String, placeholder_model: String) -> Self {
|
||||
Self {
|
||||
reason,
|
||||
placeholder_model,
|
||||
}
|
||||
}
|
||||
|
||||
fn friendly_response(&self, requested_model: &str) -> ChatResponse {
|
||||
let mut message = String::new();
|
||||
message.push_str("⚠️ Owlen is running in offline mode.\n\n");
|
||||
message.push_str(&self.reason);
|
||||
if !requested_model.is_empty() && requested_model != self.placeholder_model {
|
||||
message.push_str(&format!(
|
||||
"\n\nYou requested model '{}', but no providers are reachable.",
|
||||
requested_model
|
||||
));
|
||||
}
|
||||
message.push_str(
|
||||
"\n\nStart your preferred provider (e.g. `ollama serve`) or switch providers with `:provider` once connectivity is restored.",
|
||||
);
|
||||
|
||||
ChatResponse {
|
||||
message: Message::assistant(message),
|
||||
usage: None,
|
||||
is_streaming: false,
|
||||
is_final: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for OfflineProvider {
|
||||
fn name(&self) -> &str {
|
||||
"offline"
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>, Error> {
|
||||
Ok(vec![ModelInfo {
|
||||
id: self.placeholder_model.clone(),
|
||||
provider: "offline".to_string(),
|
||||
name: format!("Offline (fallback: {})", self.placeholder_model),
|
||||
description: Some("Placeholder model used while no providers are reachable".into()),
|
||||
context_window: None,
|
||||
capabilities: vec![],
|
||||
supports_tools: false,
|
||||
}])
|
||||
}
|
||||
|
||||
async fn send_prompt(&self, request: ChatRequest) -> Result<ChatResponse, Error> {
|
||||
Ok(self.friendly_response(&request.model))
|
||||
}
|
||||
|
||||
async fn stream_prompt(&self, request: ChatRequest) -> Result<ChatStream, Error> {
|
||||
let response = self.friendly_response(&request.model);
|
||||
Ok(Box::pin(stream::iter(vec![Ok(response)])))
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> Result<(), Error> {
|
||||
Err(Error::Provider(anyhow!(
|
||||
"offline provider cannot reach any backing models"
|
||||
)))
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &(dyn Any + Send + Sync) {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main(flavor = "multi_thread")]
|
||||
async fn main() -> Result<()> {
|
||||
// Parse command-line arguments
|
||||
@@ -362,170 +224,5 @@ async fn main() -> Result<()> {
|
||||
return run_command(command).await;
|
||||
}
|
||||
let initial_mode = if code { Mode::Code } else { Mode::Chat };
|
||||
|
||||
// Set auto-consent for TUI mode to prevent blocking stdin reads
|
||||
set_env_var("OWLEN_AUTO_CONSENT", "1");
|
||||
|
||||
let color_support = detect_terminal_color_support();
|
||||
// Load configuration (or fall back to defaults) for the session controller.
|
||||
let mut cfg = config::try_load_config().unwrap_or_default();
|
||||
if let Some(previous_theme) = apply_terminal_theme(&mut cfg, &color_support) {
|
||||
let term_label = match &color_support {
|
||||
TerminalColorSupport::Limited { term } => Cow::from(term.as_str()),
|
||||
TerminalColorSupport::Full => Cow::from("current terminal"),
|
||||
};
|
||||
eprintln!(
|
||||
"Terminal '{}' lacks full 256-color support. Using '{}' theme instead of '{}'.",
|
||||
term_label, BASIC_THEME_NAME, previous_theme
|
||||
);
|
||||
} else if let TerminalColorSupport::Limited { term } = &color_support {
|
||||
eprintln!(
|
||||
"Warning: terminal '{}' may not fully support 256-color themes.",
|
||||
term
|
||||
);
|
||||
}
|
||||
cfg.validate()?;
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
load_runtime_credentials(&mut cfg, storage.clone()).await?;
|
||||
|
||||
let (tui_tx, _tui_rx) = mpsc::unbounded_channel::<TuiRequest>();
|
||||
let tui_controller = Arc::new(TuiController::new(tui_tx));
|
||||
|
||||
// Create provider according to MCP configuration (supports legacy/local fallback)
|
||||
let provider = build_provider(&cfg)?;
|
||||
let mut offline_notice: Option<String> = None;
|
||||
let provider = match provider.health_check().await {
|
||||
Ok(_) => provider,
|
||||
Err(err) => {
|
||||
let hint = if matches!(cfg.mcp.mode, McpMode::RemotePreferred | McpMode::RemoteOnly)
|
||||
&& !cfg.mcp_servers.is_empty()
|
||||
{
|
||||
"Ensure the configured MCP server is running and reachable."
|
||||
} else {
|
||||
"Ensure Ollama is running (`ollama serve`) and reachable at the configured base_url."
|
||||
};
|
||||
let notice =
|
||||
format!("Provider health check failed: {err}. {hint} Continuing in offline mode.");
|
||||
eprintln!("{notice}");
|
||||
offline_notice = Some(notice.clone());
|
||||
let fallback_model = cfg
|
||||
.general
|
||||
.default_model
|
||||
.clone()
|
||||
.unwrap_or_else(|| "offline".to_string());
|
||||
Arc::new(OfflineProvider::new(notice, fallback_model)) as Arc<dyn Provider>
|
||||
}
|
||||
};
|
||||
|
||||
let controller =
|
||||
SessionController::new(provider, cfg, storage.clone(), tui_controller, false).await?;
|
||||
let (mut app, mut session_rx) = ChatApp::new(controller).await?;
|
||||
app.initialize_models().await?;
|
||||
if let Some(notice) = offline_notice {
|
||||
app.set_status_message(¬ice);
|
||||
app.set_system_status(notice);
|
||||
}
|
||||
|
||||
// Set the initial mode
|
||||
app.set_mode(initial_mode).await;
|
||||
|
||||
// Event infrastructure
|
||||
let cancellation_token = CancellationToken::new();
|
||||
let (event_tx, event_rx) = mpsc::unbounded_channel();
|
||||
let event_handler = EventHandler::new(event_tx, cancellation_token.clone());
|
||||
let event_handle = tokio::spawn(async move { event_handler.run().await });
|
||||
|
||||
// Terminal setup
|
||||
enable_raw_mode()?;
|
||||
let mut stdout = io::stdout();
|
||||
execute!(
|
||||
stdout,
|
||||
EnterAlternateScreen,
|
||||
EnableMouseCapture,
|
||||
EnableBracketedPaste
|
||||
)?;
|
||||
let backend = CrosstermBackend::new(stdout);
|
||||
let mut terminal = Terminal::new(backend)?;
|
||||
|
||||
let result = run_app(&mut terminal, &mut app, event_rx, &mut session_rx).await;
|
||||
|
||||
// Shutdown
|
||||
cancellation_token.cancel();
|
||||
event_handle.await?;
|
||||
|
||||
// Persist configuration updates (e.g., selected model)
|
||||
config::save_config(&app.config())?;
|
||||
|
||||
disable_raw_mode()?;
|
||||
execute!(
|
||||
terminal.backend_mut(),
|
||||
LeaveAlternateScreen,
|
||||
DisableMouseCapture,
|
||||
DisableBracketedPaste
|
||||
)?;
|
||||
terminal.show_cursor()?;
|
||||
|
||||
if let Err(err) = result {
|
||||
println!("{err:?}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn run_app(
|
||||
terminal: &mut Terminal<CrosstermBackend<io::Stdout>>,
|
||||
app: &mut ChatApp,
|
||||
mut event_rx: mpsc::UnboundedReceiver<Event>,
|
||||
session_rx: &mut mpsc::UnboundedReceiver<SessionEvent>,
|
||||
) -> Result<()> {
|
||||
let stream_draw_interval = tokio::time::Duration::from_millis(50);
|
||||
let idle_tick = tokio::time::Duration::from_millis(100);
|
||||
let mut last_draw = tokio::time::Instant::now() - stream_draw_interval;
|
||||
|
||||
loop {
|
||||
// Advance loading animation frame
|
||||
app.advance_loading_animation();
|
||||
|
||||
let streaming_active = app.streaming_count() > 0;
|
||||
let draw_due = if streaming_active {
|
||||
last_draw.elapsed() >= stream_draw_interval
|
||||
} else {
|
||||
true
|
||||
};
|
||||
|
||||
if draw_due {
|
||||
terminal.draw(|f| ui::render_chat(f, app))?;
|
||||
last_draw = tokio::time::Instant::now();
|
||||
}
|
||||
|
||||
// Process any pending LLM requests AFTER UI has been drawn
|
||||
if let Err(e) = app.process_pending_llm_request().await {
|
||||
eprintln!("Error processing LLM request: {}", e);
|
||||
}
|
||||
|
||||
// Process any pending tool executions AFTER UI has been drawn
|
||||
if let Err(e) = app.process_pending_tool_execution().await {
|
||||
eprintln!("Error processing tool execution: {}", e);
|
||||
}
|
||||
|
||||
let sleep_duration = if streaming_active {
|
||||
stream_draw_interval
|
||||
.checked_sub(last_draw.elapsed())
|
||||
.unwrap_or_else(|| tokio::time::Duration::from_millis(0))
|
||||
} else {
|
||||
idle_tick
|
||||
};
|
||||
|
||||
tokio::select! {
|
||||
Some(event) = event_rx.recv() => {
|
||||
if let AppState::Quit = app.handle_event(event).await? {
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
Some(session_event) = session_rx.recv() => {
|
||||
app.handle_session_event(session_event)?;
|
||||
}
|
||||
_ = tokio::time::sleep(sleep_duration) => {}
|
||||
}
|
||||
}
|
||||
bootstrap::launch(initial_mode).await
|
||||
}
|
||||
|
||||
259
crates/owlen-cli/src/mcp.rs
Normal file
259
crates/owlen-cli/src/mcp.rs
Normal file
@@ -0,0 +1,259 @@
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use clap::{Args, Subcommand, ValueEnum};
|
||||
use owlen_core::config::{self as core_config, Config, McpConfigScope, McpServerConfig};
|
||||
use owlen_tui::config as tui_config;
|
||||
|
||||
#[derive(Debug, Subcommand)]
|
||||
pub enum McpCommand {
|
||||
/// Add or update an MCP server in the selected scope
|
||||
Add(AddArgs),
|
||||
/// List MCP servers across scopes
|
||||
List(ListArgs),
|
||||
/// Remove an MCP server from a scope
|
||||
Remove(RemoveArgs),
|
||||
}
|
||||
|
||||
pub fn run_mcp_command(command: McpCommand) -> Result<()> {
|
||||
match command {
|
||||
McpCommand::Add(args) => handle_add(args),
|
||||
McpCommand::List(args) => handle_list(args),
|
||||
McpCommand::Remove(args) => handle_remove(args),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, ValueEnum, Default)]
|
||||
pub enum ScopeArg {
|
||||
User,
|
||||
#[default]
|
||||
Project,
|
||||
Local,
|
||||
}
|
||||
|
||||
impl From<ScopeArg> for McpConfigScope {
|
||||
fn from(value: ScopeArg) -> Self {
|
||||
match value {
|
||||
ScopeArg::User => McpConfigScope::User,
|
||||
ScopeArg::Project => McpConfigScope::Project,
|
||||
ScopeArg::Local => McpConfigScope::Local,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
pub struct AddArgs {
|
||||
/// Logical name used to reference the server
|
||||
pub name: String,
|
||||
/// Command or endpoint invoked for the server
|
||||
pub command: String,
|
||||
/// Transport mechanism (stdio, http, websocket)
|
||||
#[arg(long, default_value = "stdio")]
|
||||
pub transport: String,
|
||||
/// Configuration scope to write the server into
|
||||
#[arg(long, value_enum, default_value_t = ScopeArg::Project)]
|
||||
pub scope: ScopeArg,
|
||||
/// Environment variables (KEY=VALUE) passed to the server process
|
||||
#[arg(long = "env")]
|
||||
pub env: Vec<String>,
|
||||
/// Additional arguments appended when launching the server
|
||||
#[arg(trailing_var_arg = true, value_name = "ARG")]
|
||||
pub args: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args, Default)]
|
||||
pub struct ListArgs {
|
||||
/// Restrict output to a specific configuration scope
|
||||
#[arg(long, value_enum)]
|
||||
pub scope: Option<ScopeArg>,
|
||||
/// Display only the effective servers (after precedence resolution)
|
||||
#[arg(long)]
|
||||
pub effective_only: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
pub struct RemoveArgs {
|
||||
/// Name of the server to remove
|
||||
pub name: String,
|
||||
/// Optional explicit scope to remove from
|
||||
#[arg(long, value_enum)]
|
||||
pub scope: Option<ScopeArg>,
|
||||
}
|
||||
|
||||
fn handle_add(args: AddArgs) -> Result<()> {
|
||||
let mut config = load_config()?;
|
||||
let scope: McpConfigScope = args.scope.into();
|
||||
let mut env_map = HashMap::new();
|
||||
for pair in &args.env {
|
||||
let (key, value) = pair
|
||||
.split_once('=')
|
||||
.ok_or_else(|| anyhow!("Environment pairs must use KEY=VALUE syntax: '{}'", pair))?;
|
||||
if key.trim().is_empty() {
|
||||
return Err(anyhow!("Environment variable name cannot be empty"));
|
||||
}
|
||||
env_map.insert(key.trim().to_string(), value.to_string());
|
||||
}
|
||||
|
||||
let server = McpServerConfig {
|
||||
name: args.name.clone(),
|
||||
command: args.command.clone(),
|
||||
args: args.args.clone(),
|
||||
transport: args.transport.to_lowercase(),
|
||||
env: env_map,
|
||||
oauth: None,
|
||||
};
|
||||
|
||||
config.add_mcp_server(scope, server.clone(), None)?;
|
||||
if matches!(scope, McpConfigScope::User) {
|
||||
tui_config::save_config(&config)?;
|
||||
}
|
||||
|
||||
if let Some(path) = core_config::mcp_scope_path(scope, None) {
|
||||
println!(
|
||||
"Registered MCP server '{}' in {} scope ({})",
|
||||
server.name,
|
||||
scope,
|
||||
path.display()
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
"Registered MCP server '{}' in {} scope.",
|
||||
server.name, scope
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_list(args: ListArgs) -> Result<()> {
|
||||
let mut config = load_config()?;
|
||||
config.refresh_mcp_servers(None)?;
|
||||
|
||||
let scoped = config.scoped_mcp_servers();
|
||||
if scoped.is_empty() {
|
||||
println!("No MCP servers configured.");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let filter_scope = args.scope.map(|scope| scope.into());
|
||||
let effective = config.effective_mcp_servers();
|
||||
let mut active = HashSet::new();
|
||||
for server in effective {
|
||||
active.insert((
|
||||
server.name.clone(),
|
||||
server.command.clone(),
|
||||
server.transport.to_lowercase(),
|
||||
));
|
||||
}
|
||||
|
||||
println!(
|
||||
"{:<2} {:<8} {:<20} {:<10} Command",
|
||||
"", "Scope", "Name", "Transport"
|
||||
);
|
||||
for entry in scoped {
|
||||
if filter_scope
|
||||
.as_ref()
|
||||
.is_some_and(|target_scope| entry.scope != *target_scope)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let payload = format_command_line(&entry.config.command, &entry.config.args);
|
||||
let key = (
|
||||
entry.config.name.clone(),
|
||||
entry.config.command.clone(),
|
||||
entry.config.transport.to_lowercase(),
|
||||
);
|
||||
let marker = if active.contains(&key) { "*" } else { " " };
|
||||
|
||||
if args.effective_only && marker != "*" {
|
||||
continue;
|
||||
}
|
||||
|
||||
println!(
|
||||
"{} {:<8} {:<20} {:<10} {}",
|
||||
marker, entry.scope, entry.config.name, entry.config.transport, payload
|
||||
);
|
||||
}
|
||||
|
||||
let scoped_resources = config.scoped_mcp_resources();
|
||||
if !scoped_resources.is_empty() {
|
||||
println!();
|
||||
println!("{:<2} {:<8} {:<30} Title", "", "Scope", "Resource");
|
||||
let effective_keys: HashSet<(String, String)> = config
|
||||
.effective_mcp_resources()
|
||||
.iter()
|
||||
.map(|res| (res.server.clone(), res.uri.clone()))
|
||||
.collect();
|
||||
|
||||
for entry in scoped_resources {
|
||||
if filter_scope
|
||||
.as_ref()
|
||||
.is_some_and(|target_scope| entry.scope != *target_scope)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let key = (entry.config.server.clone(), entry.config.uri.clone());
|
||||
let marker = if effective_keys.contains(&key) {
|
||||
"*"
|
||||
} else {
|
||||
" "
|
||||
};
|
||||
if args.effective_only && marker != "*" {
|
||||
continue;
|
||||
}
|
||||
|
||||
let reference = format!("@{}:{}", entry.config.server, entry.config.uri);
|
||||
let title = entry.config.title.as_deref().unwrap_or("—");
|
||||
|
||||
println!("{} {:<8} {:<30} {}", marker, entry.scope, reference, title);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_remove(args: RemoveArgs) -> Result<()> {
|
||||
let mut config = load_config()?;
|
||||
let scope_hint = args.scope.map(|scope| scope.into());
|
||||
let result = config.remove_mcp_server(scope_hint, &args.name, None)?;
|
||||
|
||||
match result {
|
||||
Some(scope) => {
|
||||
if matches!(scope, McpConfigScope::User) {
|
||||
tui_config::save_config(&config)?;
|
||||
}
|
||||
|
||||
if let Some(path) = core_config::mcp_scope_path(scope, None) {
|
||||
println!(
|
||||
"Removed MCP server '{}' from {} scope ({})",
|
||||
args.name,
|
||||
scope,
|
||||
path.display()
|
||||
);
|
||||
} else {
|
||||
println!("Removed MCP server '{}' from {} scope.", args.name, scope);
|
||||
}
|
||||
}
|
||||
None => {
|
||||
println!("No MCP server named '{}' was found.", args.name);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn load_config() -> Result<Config> {
|
||||
let mut config = tui_config::try_load_config().unwrap_or_default();
|
||||
config.refresh_mcp_servers(None)?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
fn format_command_line(command: &str, args: &[String]) -> String {
|
||||
if args.is_empty() {
|
||||
command.to_string()
|
||||
} else {
|
||||
format!("{} {}", command, args.join(" "))
|
||||
}
|
||||
}
|
||||
@@ -50,3 +50,4 @@ ollama-rs = { version = "0.3", features = ["stream", "headers"] }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = { workspace = true }
|
||||
httpmock = "0.7"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -58,9 +58,14 @@ impl ConsentManager {
|
||||
/// Load consent records from vault storage
|
||||
pub fn from_vault(vault: &Arc<std::sync::Mutex<VaultHandle>>) -> Self {
|
||||
let guard = vault.lock().expect("Vault mutex poisoned");
|
||||
if let Some(consent_data) = guard.settings().get("consent_records")
|
||||
&& let Ok(permanent_records) =
|
||||
serde_json::from_value::<HashMap<String, ConsentRecord>>(consent_data.clone())
|
||||
if let Some(permanent_records) =
|
||||
guard
|
||||
.settings()
|
||||
.get("consent_records")
|
||||
.and_then(|consent_data| {
|
||||
serde_json::from_value::<HashMap<String, ConsentRecord>>(consent_data.clone())
|
||||
.ok()
|
||||
})
|
||||
{
|
||||
return Self {
|
||||
permanent_records,
|
||||
@@ -90,15 +95,19 @@ impl ConsentManager {
|
||||
endpoints: Vec<String>,
|
||||
) -> Result<ConsentScope> {
|
||||
// Check if already granted permanently
|
||||
if let Some(existing) = self.permanent_records.get(tool_name)
|
||||
&& existing.scope == ConsentScope::Permanent
|
||||
if self
|
||||
.permanent_records
|
||||
.get(tool_name)
|
||||
.is_some_and(|existing| existing.scope == ConsentScope::Permanent)
|
||||
{
|
||||
return Ok(ConsentScope::Permanent);
|
||||
}
|
||||
|
||||
// Check if granted for session
|
||||
if let Some(existing) = self.session_records.get(tool_name)
|
||||
&& existing.scope == ConsentScope::Session
|
||||
if self
|
||||
.session_records
|
||||
.get(tool_name)
|
||||
.is_some_and(|existing| existing.scope == ConsentScope::Session)
|
||||
{
|
||||
return Ok(ConsentScope::Session);
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@ use std::sync::Arc;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{Error, Result, storage::StorageManager};
|
||||
use crate::{Error, Result, oauth::OAuthToken, storage::StorageManager};
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct ApiCredentials {
|
||||
@@ -31,6 +31,10 @@ impl CredentialManager {
|
||||
format!("{}_{}", self.namespace, tool_name)
|
||||
}
|
||||
|
||||
fn oauth_storage_key(&self, resource: &str) -> String {
|
||||
self.namespaced_key(&format!("oauth_{resource}"))
|
||||
}
|
||||
|
||||
pub async fn store_credentials(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
@@ -68,4 +72,37 @@ impl CredentialManager {
|
||||
let key = self.namespaced_key(tool_name);
|
||||
self.storage.delete_secure_item(&key).await
|
||||
}
|
||||
|
||||
pub async fn store_oauth_token(&self, resource: &str, token: &OAuthToken) -> Result<()> {
|
||||
let key = self.oauth_storage_key(resource);
|
||||
let payload = serde_json::to_vec(token).map_err(|err| {
|
||||
Error::Storage(format!(
|
||||
"Failed to serialize OAuth token for secure storage: {err}"
|
||||
))
|
||||
})?;
|
||||
self.storage
|
||||
.store_secure_item(&key, &payload, &self.master_key)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn load_oauth_token(&self, resource: &str) -> Result<Option<OAuthToken>> {
|
||||
let key = self.oauth_storage_key(resource);
|
||||
let raw = self
|
||||
.storage
|
||||
.load_secure_item(&key, &self.master_key)
|
||||
.await?;
|
||||
if let Some(bytes) = raw {
|
||||
let token = serde_json::from_slice(&bytes).map_err(|err| {
|
||||
Error::Storage(format!("Failed to deserialize stored OAuth token: {err}"))
|
||||
})?;
|
||||
Ok(Some(token))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_oauth_token(&self, resource: &str) -> Result<()> {
|
||||
let key = self.oauth_storage_key(resource);
|
||||
self.storage.delete_secure_item(&key).await
|
||||
}
|
||||
}
|
||||
|
||||
32
crates/owlen-core/src/facade/llm_client.rs
Normal file
32
crates/owlen-core/src/facade/llm_client.rs
Normal file
@@ -0,0 +1,32 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
|
||||
use crate::{
|
||||
Result,
|
||||
llm::ChatStream,
|
||||
mcp::{McpToolCall, McpToolDescriptor, McpToolResponse},
|
||||
types::{ChatRequest, ChatResponse, ModelInfo},
|
||||
};
|
||||
|
||||
/// Object-safe facade for interacting with LLM backends.
|
||||
#[async_trait]
|
||||
pub trait LlmClient: Send + Sync {
|
||||
/// List the models exposed by this client.
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>>;
|
||||
|
||||
/// Issue a one-shot chat request and wait for the complete response.
|
||||
async fn send_chat(&self, request: ChatRequest) -> Result<ChatResponse>;
|
||||
|
||||
/// Stream chat responses incrementally.
|
||||
async fn stream_chat(&self, request: ChatRequest) -> Result<ChatStream>;
|
||||
|
||||
/// Enumerate tools exposed by the backing provider.
|
||||
async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>>;
|
||||
|
||||
/// Invoke a tool exposed by the provider.
|
||||
async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse>;
|
||||
}
|
||||
|
||||
/// Convenience alias for trait-object clients.
|
||||
pub type DynLlmClient = Arc<dyn LlmClient>;
|
||||
1
crates/owlen-core/src/facade/mod.rs
Normal file
1
crates/owlen-core/src/facade/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod llm_client;
|
||||
@@ -1,3 +1,5 @@
|
||||
#![allow(clippy::collapsible_if)] // TODO: Remove once we can rely on Rust 2024 let-chains
|
||||
|
||||
//! Core traits and types for OWLEN LLM client
|
||||
//!
|
||||
//! This crate provides the foundational abstractions for building
|
||||
@@ -9,12 +11,15 @@ pub mod consent;
|
||||
pub mod conversation;
|
||||
pub mod credentials;
|
||||
pub mod encryption;
|
||||
pub mod facade;
|
||||
pub mod formatting;
|
||||
pub mod input;
|
||||
pub mod llm;
|
||||
pub mod mcp;
|
||||
pub mod mode;
|
||||
pub mod model;
|
||||
pub mod oauth;
|
||||
pub mod provider;
|
||||
pub mod providers;
|
||||
pub mod router;
|
||||
pub mod sandbox;
|
||||
@@ -36,7 +41,9 @@ pub use credentials::*;
|
||||
pub use encryption::*;
|
||||
pub use formatting::*;
|
||||
pub use input::*;
|
||||
pub use oauth::*;
|
||||
// Export MCP types but exclude test_utils to avoid ambiguity
|
||||
pub use facade::llm_client::*;
|
||||
pub use llm::{
|
||||
ChatStream, LlmProvider, Provider, ProviderConfig, ProviderRegistry, send_via_stream,
|
||||
};
|
||||
@@ -46,6 +53,7 @@ pub use mcp::{
|
||||
};
|
||||
pub use mode::*;
|
||||
pub use model::*;
|
||||
pub use provider::*;
|
||||
pub use providers::*;
|
||||
pub use router::*;
|
||||
pub use sandbox::*;
|
||||
|
||||
@@ -144,17 +144,57 @@ where
|
||||
/// Runtime configuration for a provider instance.
|
||||
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
|
||||
pub struct ProviderConfig {
|
||||
/// Provider type identifier.
|
||||
/// Whether this provider should be activated.
|
||||
#[serde(default = "ProviderConfig::default_enabled")]
|
||||
pub enabled: bool,
|
||||
/// Provider type identifier used to resolve implementations.
|
||||
#[serde(default)]
|
||||
pub provider_type: String,
|
||||
/// Base URL for API calls.
|
||||
#[serde(default)]
|
||||
pub base_url: Option<String>,
|
||||
/// API key or token material.
|
||||
#[serde(default)]
|
||||
pub api_key: Option<String>,
|
||||
/// Environment variable holding the API key.
|
||||
#[serde(default)]
|
||||
pub api_key_env: Option<String>,
|
||||
/// Additional provider-specific configuration.
|
||||
#[serde(flatten)]
|
||||
pub extra: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
impl ProviderConfig {
|
||||
const fn default_enabled() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
/// Merge the current configuration with overrides from `other`.
|
||||
pub fn merge_from(&mut self, mut other: ProviderConfig) {
|
||||
self.enabled = other.enabled;
|
||||
|
||||
if !other.provider_type.is_empty() {
|
||||
self.provider_type = other.provider_type;
|
||||
}
|
||||
|
||||
if let Some(base_url) = other.base_url.take() {
|
||||
self.base_url = Some(base_url);
|
||||
}
|
||||
|
||||
if let Some(api_key) = other.api_key.take() {
|
||||
self.api_key = Some(api_key);
|
||||
}
|
||||
|
||||
if let Some(api_key_env) = other.api_key_env.take() {
|
||||
self.api_key_env = Some(api_key_env);
|
||||
}
|
||||
|
||||
if !other.extra.is_empty() {
|
||||
self.extra.extend(other.extra);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Static registry of providers available to the application.
|
||||
pub struct ProviderRegistry {
|
||||
providers: HashMap<String, Arc<dyn Provider>>,
|
||||
|
||||
@@ -3,7 +3,10 @@
|
||||
/// Provides a unified interface for creating MCP clients based on configuration.
|
||||
/// Supports switching between local (in-process) and remote (STDIO) execution modes.
|
||||
use super::client::McpClient;
|
||||
use super::{LocalMcpClient, remote_client::RemoteMcpClient};
|
||||
use super::{
|
||||
LocalMcpClient,
|
||||
remote_client::{McpRuntimeSecrets, RemoteMcpClient},
|
||||
};
|
||||
use crate::config::{Config, McpMode};
|
||||
use crate::tools::registry::ToolRegistry;
|
||||
use crate::validation::SchemaValidator;
|
||||
@@ -33,6 +36,14 @@ impl McpClientFactory {
|
||||
|
||||
/// Create an MCP client based on the current configuration.
|
||||
pub fn create(&self) -> Result<Box<dyn McpClient>> {
|
||||
self.create_with_secrets(None)
|
||||
}
|
||||
|
||||
/// Create an MCP client using optional runtime secrets (OAuth tokens, env overrides).
|
||||
pub fn create_with_secrets(
|
||||
&self,
|
||||
runtime: Option<McpRuntimeSecrets>,
|
||||
) -> Result<Box<dyn McpClient>> {
|
||||
match self.config.mcp.mode {
|
||||
McpMode::Disabled => Err(Error::Config(
|
||||
"MCP mode is set to 'disabled'; tooling cannot function in this configuration."
|
||||
@@ -48,14 +59,14 @@ impl McpClientFactory {
|
||||
)))
|
||||
}
|
||||
McpMode::RemoteOnly => {
|
||||
let server_cfg = self.config.mcp_servers.first().ok_or_else(|| {
|
||||
let server_cfg = self.config.effective_mcp_servers().first().ok_or_else(|| {
|
||||
Error::Config(
|
||||
"MCP mode 'remote_only' requires at least one entry in [[mcp_servers]]"
|
||||
.to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
RemoteMcpClient::new_with_config(server_cfg)
|
||||
RemoteMcpClient::new_with_runtime(server_cfg, runtime)
|
||||
.map(|client| Box::new(client) as Box<dyn McpClient>)
|
||||
.map_err(|e| {
|
||||
Error::Config(format!(
|
||||
@@ -65,8 +76,8 @@ impl McpClientFactory {
|
||||
})
|
||||
}
|
||||
McpMode::RemotePreferred => {
|
||||
if let Some(server_cfg) = self.config.mcp_servers.first() {
|
||||
match RemoteMcpClient::new_with_config(server_cfg) {
|
||||
if let Some(server_cfg) = self.config.effective_mcp_servers().first() {
|
||||
match RemoteMcpClient::new_with_runtime(server_cfg, runtime.clone()) {
|
||||
Ok(client) => {
|
||||
info!(
|
||||
"Connected to remote MCP server '{}' via {} transport.",
|
||||
@@ -125,7 +136,8 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_factory_creates_local_client_when_no_servers_configured() {
|
||||
let config = Config::default();
|
||||
let mut config = Config::default();
|
||||
config.refresh_mcp_servers(None).unwrap();
|
||||
|
||||
let factory = build_factory(config);
|
||||
|
||||
@@ -139,6 +151,7 @@ mod tests {
|
||||
let mut config = Config::default();
|
||||
config.mcp.mode = McpMode::RemoteOnly;
|
||||
config.mcp_servers.clear();
|
||||
config.refresh_mcp_servers(None).unwrap();
|
||||
|
||||
let factory = build_factory(config);
|
||||
let result = factory.create();
|
||||
@@ -156,7 +169,9 @@ mod tests {
|
||||
args: Vec::new(),
|
||||
transport: "stdio".to_string(),
|
||||
env: std::collections::HashMap::new(),
|
||||
oauth: None,
|
||||
}];
|
||||
config.refresh_mcp_servers(None).unwrap();
|
||||
|
||||
let factory = build_factory(config);
|
||||
let result = factory.create();
|
||||
|
||||
@@ -305,6 +305,7 @@ mod tests {
|
||||
args: vec![],
|
||||
transport: "http".to_string(),
|
||||
env: std::collections::HashMap::new(),
|
||||
oauth: None,
|
||||
};
|
||||
|
||||
if let Ok(client) = RemoteMcpClient::new_with_config(&config) {
|
||||
|
||||
@@ -156,13 +156,14 @@ mod tests {
|
||||
use super::*;
|
||||
use crate::mcp::LocalMcpClient;
|
||||
use crate::tools::registry::ToolRegistry;
|
||||
use crate::ui::NoOpUiController;
|
||||
use crate::validation::SchemaValidator;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_permission_layer_filters_dangerous_tools() {
|
||||
let config = Arc::new(Config::default());
|
||||
let ui = Arc::new(crate::ui::NoOpUiController);
|
||||
let ui = Arc::new(NoOpUiController);
|
||||
let registry = Arc::new(ToolRegistry::new(
|
||||
Arc::new(tokio::sync::Mutex::new((*config).clone())),
|
||||
ui,
|
||||
@@ -186,7 +187,7 @@ mod tests {
|
||||
#[tokio::test]
|
||||
async fn test_consent_callback_is_invoked() {
|
||||
let config = Arc::new(Config::default());
|
||||
let ui = Arc::new(crate::ui::NoOpUiController);
|
||||
let ui = Arc::new(NoOpUiController);
|
||||
let registry = Arc::new(ToolRegistry::new(
|
||||
Arc::new(tokio::sync::Mutex::new((*config).clone())),
|
||||
ui,
|
||||
|
||||
@@ -7,11 +7,15 @@ use crate::consent::{ConsentManager, ConsentScope};
|
||||
use crate::tools::{Tool, WebScrapeTool, WebSearchTool};
|
||||
use crate::types::ModelInfo;
|
||||
use crate::types::{ChatResponse, Message, Role};
|
||||
use crate::{Error, LlmProvider, Result, mode::Mode, send_via_stream};
|
||||
use crate::{
|
||||
ChatStream, Error, LlmProvider, Result, facade::llm_client::LlmClient, mode::Mode,
|
||||
send_via_stream,
|
||||
};
|
||||
use anyhow::anyhow;
|
||||
use futures::{StreamExt, future::BoxFuture, stream};
|
||||
use reqwest::Client as HttpClient;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicU64, Ordering};
|
||||
@@ -39,6 +43,15 @@ pub struct RemoteMcpClient {
|
||||
ws_endpoint: Option<String>,
|
||||
// Incrementing request identifier.
|
||||
next_id: AtomicU64,
|
||||
// Optional HTTP header (name, value) injected into every request.
|
||||
http_header: Option<(String, String)>,
|
||||
}
|
||||
|
||||
/// Runtime secrets provided when constructing an MCP client.
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct McpRuntimeSecrets {
|
||||
pub env_overrides: HashMap<String, String>,
|
||||
pub http_header: Option<(String, String)>,
|
||||
}
|
||||
|
||||
impl RemoteMcpClient {
|
||||
@@ -48,6 +61,14 @@ impl RemoteMcpClient {
|
||||
/// Spawn an external MCP server based on a configuration entry.
|
||||
/// The server must communicate over STDIO (the only supported transport).
|
||||
pub fn new_with_config(config: &crate::config::McpServerConfig) -> Result<Self> {
|
||||
Self::new_with_runtime(config, None)
|
||||
}
|
||||
|
||||
pub fn new_with_runtime(
|
||||
config: &crate::config::McpServerConfig,
|
||||
runtime: Option<McpRuntimeSecrets>,
|
||||
) -> Result<Self> {
|
||||
let mut runtime = runtime.unwrap_or_default();
|
||||
let transport = config.transport.to_lowercase();
|
||||
match transport.as_str() {
|
||||
"stdio" => {
|
||||
@@ -64,6 +85,9 @@ impl RemoteMcpClient {
|
||||
for (k, v) in config.env.iter() {
|
||||
cmd.env(k, v);
|
||||
}
|
||||
for (k, v) in runtime.env_overrides.drain() {
|
||||
cmd.env(k, v);
|
||||
}
|
||||
|
||||
let mut child = cmd.spawn().map_err(|e| {
|
||||
Error::Io(std::io::Error::new(
|
||||
@@ -92,6 +116,7 @@ impl RemoteMcpClient {
|
||||
ws_stream: None,
|
||||
ws_endpoint: None,
|
||||
next_id: AtomicU64::new(1),
|
||||
http_header: None,
|
||||
})
|
||||
}
|
||||
"http" => {
|
||||
@@ -109,6 +134,7 @@ impl RemoteMcpClient {
|
||||
ws_stream: None,
|
||||
ws_endpoint: None,
|
||||
next_id: AtomicU64::new(1),
|
||||
http_header: runtime.http_header.take(),
|
||||
})
|
||||
}
|
||||
"websocket" => {
|
||||
@@ -132,6 +158,7 @@ impl RemoteMcpClient {
|
||||
ws_stream: Some(Arc::new(Mutex::new(ws_stream))),
|
||||
ws_endpoint: Some(ws_url),
|
||||
next_id: AtomicU64::new(1),
|
||||
http_header: runtime.http_header.take(),
|
||||
})
|
||||
}
|
||||
other => Err(Error::NotImplemented(format!(
|
||||
@@ -171,6 +198,7 @@ impl RemoteMcpClient {
|
||||
args: Vec::new(),
|
||||
transport: "stdio".to_string(),
|
||||
env: std::collections::HashMap::new(),
|
||||
oauth: None,
|
||||
};
|
||||
Self::new_with_config(&config)
|
||||
}
|
||||
@@ -193,8 +221,11 @@ impl RemoteMcpClient {
|
||||
.http_endpoint
|
||||
.as_ref()
|
||||
.ok_or_else(|| Error::Network("Missing HTTP endpoint".into()))?;
|
||||
let resp = client
|
||||
.post(endpoint)
|
||||
let mut builder = client.post(endpoint);
|
||||
if let Some((ref header_name, ref header_value)) = self.http_header {
|
||||
builder = builder.header(header_name, header_value);
|
||||
}
|
||||
let resp = builder
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
@@ -536,3 +567,27 @@ impl LlmProvider for RemoteMcpClient {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl LlmClient for RemoteMcpClient {
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>> {
|
||||
<Self as LlmProvider>::list_models(self).await
|
||||
}
|
||||
|
||||
async fn send_chat(&self, request: crate::types::ChatRequest) -> Result<ChatResponse> {
|
||||
<Self as LlmProvider>::send_prompt(self, request).await
|
||||
}
|
||||
|
||||
async fn stream_chat(&self, request: crate::types::ChatRequest) -> Result<ChatStream> {
|
||||
let stream = <Self as LlmProvider>::stream_prompt(self, request).await?;
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
|
||||
async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>> {
|
||||
<Self as McpClient>::list_tools(self).await
|
||||
}
|
||||
|
||||
async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
|
||||
<Self as McpClient>::call_tool(self, call).await
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ impl ModelManager {
|
||||
F: FnOnce() -> Fut,
|
||||
Fut: Future<Output = Result<Vec<ModelInfo>>>,
|
||||
{
|
||||
if !force_refresh && let Some(models) = self.cached_if_fresh().await {
|
||||
if let (false, Some(models)) = (force_refresh, self.cached_if_fresh().await) {
|
||||
return Ok(models);
|
||||
}
|
||||
|
||||
|
||||
507
crates/owlen-core/src/oauth.rs
Normal file
507
crates/owlen-core/src/oauth.rs
Normal file
@@ -0,0 +1,507 @@
|
||||
use std::time::Duration as StdDuration;
|
||||
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{Error, Result, config::McpOAuthConfig};
|
||||
|
||||
/// Persisted OAuth token set for MCP servers and providers.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
|
||||
pub struct OAuthToken {
|
||||
/// Bearer access token returned by the authorization server.
|
||||
pub access_token: String,
|
||||
/// Optional refresh token if the provider issues one.
|
||||
#[serde(default)]
|
||||
pub refresh_token: Option<String>,
|
||||
/// Absolute UTC expiration timestamp for the access token.
|
||||
#[serde(default)]
|
||||
pub expires_at: Option<DateTime<Utc>>,
|
||||
/// Optional space-delimited scope string supplied by the provider.
|
||||
#[serde(default)]
|
||||
pub scope: Option<String>,
|
||||
/// Token type reported by the provider (typically `Bearer`).
|
||||
#[serde(default)]
|
||||
pub token_type: Option<String>,
|
||||
}
|
||||
|
||||
impl OAuthToken {
|
||||
/// Returns `true` if the access token has expired at the provided instant.
|
||||
pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
|
||||
matches!(self.expires_at, Some(expiry) if now >= expiry)
|
||||
}
|
||||
|
||||
/// Returns `true` if the token will expire within the supplied duration window.
|
||||
pub fn will_expire_within(&self, window: Duration, now: DateTime<Utc>) -> bool {
|
||||
matches!(self.expires_at, Some(expiry) if expiry - now <= window)
|
||||
}
|
||||
}
|
||||
|
||||
/// Active device-authorization session details returned by the authorization server.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DeviceAuthorization {
|
||||
pub device_code: String,
|
||||
pub user_code: String,
|
||||
pub verification_uri: String,
|
||||
pub verification_uri_complete: Option<String>,
|
||||
pub expires_at: DateTime<Utc>,
|
||||
pub interval: StdDuration,
|
||||
pub message: Option<String>,
|
||||
}
|
||||
|
||||
impl DeviceAuthorization {
|
||||
pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
|
||||
now >= self.expires_at
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of polling the token endpoint during a device-authorization flow.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum DevicePollState {
|
||||
Pending { retry_in: StdDuration },
|
||||
Complete(OAuthToken),
|
||||
}
|
||||
|
||||
pub struct OAuthClient {
|
||||
http: Client,
|
||||
config: McpOAuthConfig,
|
||||
}
|
||||
|
||||
impl OAuthClient {
|
||||
pub fn new(config: McpOAuthConfig) -> Result<Self> {
|
||||
let http = Client::builder()
|
||||
.user_agent("OwlenOAuth/1.0")
|
||||
.build()
|
||||
.map_err(|err| Error::Network(format!("Failed to construct HTTP client: {err}")))?;
|
||||
Ok(Self { http, config })
|
||||
}
|
||||
|
||||
fn scope_value(&self) -> Option<String> {
|
||||
if self.config.scopes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(self.config.scopes.join(" "))
|
||||
}
|
||||
}
|
||||
|
||||
fn token_request_base(&self) -> Vec<(String, String)> {
|
||||
let mut params = vec![("client_id".to_string(), self.config.client_id.clone())];
|
||||
if let Some(secret) = &self.config.client_secret {
|
||||
params.push(("client_secret".to_string(), secret.clone()));
|
||||
}
|
||||
params
|
||||
}
|
||||
|
||||
pub async fn start_device_authorization(&self) -> Result<DeviceAuthorization> {
|
||||
let device_url = self
|
||||
.config
|
||||
.device_authorization_url
|
||||
.as_ref()
|
||||
.ok_or_else(|| {
|
||||
Error::Config("Device authorization endpoint is not configured.".to_string())
|
||||
})?;
|
||||
|
||||
let mut params = self.token_request_base();
|
||||
if let Some(scope) = self.scope_value() {
|
||||
params.push(("scope".to_string(), scope));
|
||||
}
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(device_url)
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| map_http_error("start device authorization", err))?;
|
||||
|
||||
let status = response.status();
|
||||
let payload = response
|
||||
.json::<DeviceAuthorizationResponse>()
|
||||
.await
|
||||
.map_err(|err| {
|
||||
Error::Auth(format!(
|
||||
"Failed to parse device authorization response (status {status}): {err}"
|
||||
))
|
||||
})?;
|
||||
|
||||
let expires_at =
|
||||
Utc::now() + Duration::seconds(payload.expires_in.min(i64::MAX as u64) as i64);
|
||||
let interval = StdDuration::from_secs(payload.interval.unwrap_or(5).max(1));
|
||||
|
||||
Ok(DeviceAuthorization {
|
||||
device_code: payload.device_code,
|
||||
user_code: payload.user_code,
|
||||
verification_uri: payload.verification_uri,
|
||||
verification_uri_complete: payload.verification_uri_complete,
|
||||
expires_at,
|
||||
interval,
|
||||
message: payload.message,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn poll_device_token(&self, auth: &DeviceAuthorization) -> Result<DevicePollState> {
|
||||
let mut params = self.token_request_base();
|
||||
params.push(("grant_type".to_string(), DEVICE_CODE_GRANT.to_string()));
|
||||
params.push(("device_code".to_string(), auth.device_code.clone()));
|
||||
if let Some(scope) = self.scope_value() {
|
||||
params.push(("scope".to_string(), scope));
|
||||
}
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&self.config.token_url)
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| map_http_error("poll device token", err))?;
|
||||
|
||||
let status = response.status();
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(|err| map_http_error("read token response", err))?;
|
||||
|
||||
if status.is_success() {
|
||||
let payload: TokenResponse = serde_json::from_str(&text).map_err(|err| {
|
||||
Error::Auth(format!(
|
||||
"Failed to parse OAuth token response: {err}; body: {text}"
|
||||
))
|
||||
})?;
|
||||
return Ok(DevicePollState::Complete(oauth_token_from_response(
|
||||
payload,
|
||||
)));
|
||||
}
|
||||
|
||||
let error = serde_json::from_str::<OAuthErrorResponse>(&text).unwrap_or_else(|_| {
|
||||
OAuthErrorResponse {
|
||||
error: "unknown_error".to_string(),
|
||||
error_description: Some(text.clone()),
|
||||
}
|
||||
});
|
||||
|
||||
match error.error.as_str() {
|
||||
"authorization_pending" => Ok(DevicePollState::Pending {
|
||||
retry_in: auth.interval,
|
||||
}),
|
||||
"slow_down" => Ok(DevicePollState::Pending {
|
||||
retry_in: auth.interval.saturating_add(StdDuration::from_secs(5)),
|
||||
}),
|
||||
"access_denied" => {
|
||||
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
|
||||
"User declined authorization".to_string()
|
||||
})))
|
||||
}
|
||||
"expired_token" | "expired_device_code" => {
|
||||
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
|
||||
"Device authorization expired".to_string()
|
||||
})))
|
||||
}
|
||||
other => Err(Error::Auth(
|
||||
error
|
||||
.error_description
|
||||
.unwrap_or_else(|| format!("OAuth error: {other}")),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn refresh_token(&self, refresh_token: &str) -> Result<OAuthToken> {
|
||||
let mut params = self.token_request_base();
|
||||
params.push(("grant_type".to_string(), "refresh_token".to_string()));
|
||||
params.push(("refresh_token".to_string(), refresh_token.to_string()));
|
||||
if let Some(scope) = self.scope_value() {
|
||||
params.push(("scope".to_string(), scope));
|
||||
}
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&self.config.token_url)
|
||||
.form(¶ms)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| map_http_error("refresh OAuth token", err))?;
|
||||
|
||||
let status = response.status();
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(|err| map_http_error("read refresh response", err))?;
|
||||
|
||||
if status.is_success() {
|
||||
let payload: TokenResponse = serde_json::from_str(&text).map_err(|err| {
|
||||
Error::Auth(format!(
|
||||
"Failed to parse OAuth refresh response: {err}; body: {text}"
|
||||
))
|
||||
})?;
|
||||
Ok(oauth_token_from_response(payload))
|
||||
} else {
|
||||
let error = serde_json::from_str::<OAuthErrorResponse>(&text).unwrap_or_else(|_| {
|
||||
OAuthErrorResponse {
|
||||
error: "unknown_error".to_string(),
|
||||
error_description: Some(text.clone()),
|
||||
}
|
||||
});
|
||||
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
|
||||
format!("OAuth token refresh failed: {}", error.error)
|
||||
})))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const DEVICE_CODE_GRANT: &str = "urn:ietf:params:oauth:grant-type:device_code";
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeviceAuthorizationResponse {
|
||||
device_code: String,
|
||||
user_code: String,
|
||||
verification_uri: String,
|
||||
#[serde(default)]
|
||||
verification_uri_complete: Option<String>,
|
||||
expires_in: u64,
|
||||
#[serde(default)]
|
||||
interval: Option<u64>,
|
||||
#[serde(default)]
|
||||
message: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenResponse {
|
||||
access_token: String,
|
||||
#[serde(default)]
|
||||
refresh_token: Option<String>,
|
||||
#[serde(default)]
|
||||
expires_in: Option<u64>,
|
||||
#[serde(default)]
|
||||
scope: Option<String>,
|
||||
#[serde(default)]
|
||||
token_type: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OAuthErrorResponse {
|
||||
error: String,
|
||||
#[serde(default)]
|
||||
error_description: Option<String>,
|
||||
}
|
||||
|
||||
fn oauth_token_from_response(payload: TokenResponse) -> OAuthToken {
|
||||
let expires_at = payload
|
||||
.expires_in
|
||||
.map(|seconds| seconds.min(i64::MAX as u64) as i64)
|
||||
.map(|seconds| Utc::now() + Duration::seconds(seconds));
|
||||
|
||||
OAuthToken {
|
||||
access_token: payload.access_token,
|
||||
refresh_token: payload.refresh_token,
|
||||
expires_at,
|
||||
scope: payload.scope,
|
||||
token_type: payload.token_type,
|
||||
}
|
||||
}
|
||||
|
||||
fn map_http_error(action: &str, err: reqwest::Error) -> Error {
|
||||
if err.is_timeout() {
|
||||
Error::Timeout(format!("OAuth {action} request timed out: {err}"))
|
||||
} else if err.is_connect() {
|
||||
Error::Network(format!("OAuth {action} connection error: {err}"))
|
||||
} else {
|
||||
Error::Network(format!("OAuth {action} request failed: {err}"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use httpmock::prelude::*;
|
||||
use serde_json::json;
|
||||
|
||||
fn config_for(server: &MockServer) -> McpOAuthConfig {
|
||||
McpOAuthConfig {
|
||||
client_id: "test-client".to_string(),
|
||||
client_secret: None,
|
||||
authorize_url: server.url("/authorize"),
|
||||
token_url: server.url("/token"),
|
||||
device_authorization_url: Some(server.url("/device")),
|
||||
redirect_url: None,
|
||||
scopes: vec!["repo".to_string(), "user".to_string()],
|
||||
token_env: None,
|
||||
header: None,
|
||||
header_prefix: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn sample_device_authorization() -> DeviceAuthorization {
|
||||
DeviceAuthorization {
|
||||
device_code: "device-123".to_string(),
|
||||
user_code: "ABCD-EFGH".to_string(),
|
||||
verification_uri: "https://example.test/activate".to_string(),
|
||||
verification_uri_complete: Some(
|
||||
"https://example.test/activate?user_code=ABCD-EFGH".to_string(),
|
||||
),
|
||||
expires_at: Utc::now() + Duration::minutes(10),
|
||||
interval: StdDuration::from_secs(5),
|
||||
message: Some("Open the verification URL and enter the code.".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_device_authorization_returns_payload() {
|
||||
let server = MockServer::start_async().await;
|
||||
let device_mock = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST).path("/device");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"device_code": "device-123",
|
||||
"user_code": "ABCD-EFGH",
|
||||
"verification_uri": "https://example.test/activate",
|
||||
"verification_uri_complete": "https://example.test/activate?user_code=ABCD-EFGH",
|
||||
"expires_in": 600,
|
||||
"interval": 7,
|
||||
"message": "Open the verification URL and enter the code."
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let client = OAuthClient::new(config_for(&server)).expect("client");
|
||||
let auth = client
|
||||
.start_device_authorization()
|
||||
.await
|
||||
.expect("device authorization payload");
|
||||
|
||||
assert_eq!(auth.user_code, "ABCD-EFGH");
|
||||
assert_eq!(auth.interval, StdDuration::from_secs(7));
|
||||
assert!(auth.expires_at > Utc::now());
|
||||
device_mock.assert_async().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn poll_device_token_reports_pending() {
|
||||
let server = MockServer::start_async().await;
|
||||
let pending = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST)
|
||||
.path("/token")
|
||||
.body_contains(
|
||||
"grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code",
|
||||
)
|
||||
.body_contains("device_code=device-123");
|
||||
then.status(400)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"error": "authorization_pending"
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let config = config_for(&server);
|
||||
let client = OAuthClient::new(config).expect("client");
|
||||
let auth = sample_device_authorization();
|
||||
|
||||
let result = client.poll_device_token(&auth).await.expect("poll result");
|
||||
match result {
|
||||
DevicePollState::Pending { retry_in } => {
|
||||
assert_eq!(retry_in, StdDuration::from_secs(5));
|
||||
}
|
||||
other => panic!("expected pending state, got {other:?}"),
|
||||
}
|
||||
|
||||
pending.assert_async().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn poll_device_token_applies_slow_down_backoff() {
|
||||
let server = MockServer::start_async().await;
|
||||
let slow = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST).path("/token");
|
||||
then.status(400)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"error": "slow_down"
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let config = config_for(&server);
|
||||
let client = OAuthClient::new(config).expect("client");
|
||||
let auth = sample_device_authorization();
|
||||
|
||||
let result = client.poll_device_token(&auth).await.expect("poll result");
|
||||
match result {
|
||||
DevicePollState::Pending { retry_in } => {
|
||||
assert_eq!(retry_in, StdDuration::from_secs(10));
|
||||
}
|
||||
other => panic!("expected pending state, got {other:?}"),
|
||||
}
|
||||
|
||||
slow.assert_async().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn poll_device_token_returns_token_when_authorized() {
|
||||
let server = MockServer::start_async().await;
|
||||
let token = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST).path("/token");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"access_token": "token-abc",
|
||||
"refresh_token": "refresh-xyz",
|
||||
"expires_in": 3600,
|
||||
"token_type": "Bearer",
|
||||
"scope": "repo user"
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let config = config_for(&server);
|
||||
let client = OAuthClient::new(config).expect("client");
|
||||
let auth = sample_device_authorization();
|
||||
|
||||
let result = client.poll_device_token(&auth).await.expect("poll result");
|
||||
let token_info = match result {
|
||||
DevicePollState::Complete(token) => token,
|
||||
other => panic!("expected completion, got {other:?}"),
|
||||
};
|
||||
|
||||
assert_eq!(token_info.access_token, "token-abc");
|
||||
assert_eq!(token_info.refresh_token.as_deref(), Some("refresh-xyz"));
|
||||
assert!(token_info.expires_at.is_some());
|
||||
token.assert_async().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn refresh_token_roundtrip() {
|
||||
let server = MockServer::start_async().await;
|
||||
let refresh = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST)
|
||||
.path("/token")
|
||||
.body_contains("grant_type=refresh_token")
|
||||
.body_contains("refresh_token=old-refresh");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"access_token": "token-new",
|
||||
"refresh_token": "refresh-new",
|
||||
"expires_in": 1200,
|
||||
"token_type": "Bearer"
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let config = config_for(&server);
|
||||
let client = OAuthClient::new(config).expect("client");
|
||||
let token = client
|
||||
.refresh_token("old-refresh")
|
||||
.await
|
||||
.expect("refresh response");
|
||||
|
||||
assert_eq!(token.access_token, "token-new");
|
||||
assert_eq!(token.refresh_token.as_deref(), Some("refresh-new"));
|
||||
assert!(token.expires_at.is_some());
|
||||
refresh.assert_async().await;
|
||||
}
|
||||
}
|
||||
227
crates/owlen-core/src/provider/manager.rs
Normal file
227
crates/owlen-core/src/provider/manager.rs
Normal file
@@ -0,0 +1,227 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::stream::{FuturesUnordered, StreamExt};
|
||||
use log::{debug, warn};
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::{Error, Result};
|
||||
|
||||
use super::{GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderStatus};
|
||||
|
||||
/// Model information annotated with the originating provider metadata.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AnnotatedModelInfo {
|
||||
pub provider_id: String,
|
||||
pub provider_status: ProviderStatus,
|
||||
pub model: ModelInfo,
|
||||
}
|
||||
|
||||
/// Coordinates multiple [`ModelProvider`] implementations and tracks their
|
||||
/// health state.
|
||||
pub struct ProviderManager {
|
||||
providers: RwLock<HashMap<String, Arc<dyn ModelProvider>>>,
|
||||
status_cache: RwLock<HashMap<String, ProviderStatus>>,
|
||||
}
|
||||
|
||||
impl ProviderManager {
|
||||
/// Construct a new manager using the supplied configuration. Providers
|
||||
/// defined in the configuration start with a `RequiresSetup` status so
|
||||
/// that frontends can surface incomplete configuration to users.
|
||||
pub fn new(config: &Config) -> Self {
|
||||
let mut status_cache = HashMap::new();
|
||||
for provider_id in config.providers.keys() {
|
||||
status_cache.insert(provider_id.clone(), ProviderStatus::RequiresSetup);
|
||||
}
|
||||
|
||||
Self {
|
||||
providers: RwLock::new(HashMap::new()),
|
||||
status_cache: RwLock::new(status_cache),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a provider instance with the manager.
|
||||
pub async fn register_provider(&self, provider: Arc<dyn ModelProvider>) {
|
||||
let provider_id = provider.metadata().id.clone();
|
||||
debug!("registering provider {}", provider_id);
|
||||
|
||||
self.providers
|
||||
.write()
|
||||
.await
|
||||
.insert(provider_id.clone(), provider);
|
||||
self.status_cache
|
||||
.write()
|
||||
.await
|
||||
.insert(provider_id, ProviderStatus::Unavailable);
|
||||
}
|
||||
|
||||
/// Return a stream by routing the request to the designated provider.
|
||||
pub async fn generate(
|
||||
&self,
|
||||
provider_id: &str,
|
||||
request: GenerateRequest,
|
||||
) -> Result<GenerateStream> {
|
||||
let provider = {
|
||||
let guard = self.providers.read().await;
|
||||
guard.get(provider_id).cloned()
|
||||
}
|
||||
.ok_or_else(|| Error::Config(format!("provider '{provider_id}' not registered")))?;
|
||||
|
||||
match provider.generate_stream(request).await {
|
||||
Ok(stream) => {
|
||||
self.status_cache
|
||||
.write()
|
||||
.await
|
||||
.insert(provider_id.to_string(), ProviderStatus::Available);
|
||||
Ok(stream)
|
||||
}
|
||||
Err(err) => {
|
||||
self.status_cache
|
||||
.write()
|
||||
.await
|
||||
.insert(provider_id.to_string(), ProviderStatus::Unavailable);
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// List models across all providers, updating provider status along the way.
|
||||
pub async fn list_all_models(&self) -> Result<Vec<AnnotatedModelInfo>> {
|
||||
let providers: Vec<(String, Arc<dyn ModelProvider>)> = {
|
||||
let guard = self.providers.read().await;
|
||||
guard
|
||||
.iter()
|
||||
.map(|(id, provider)| (id.clone(), Arc::clone(provider)))
|
||||
.collect()
|
||||
};
|
||||
|
||||
let mut tasks = FuturesUnordered::new();
|
||||
|
||||
for (provider_id, provider) in providers {
|
||||
tasks.push(async move {
|
||||
let log_id = provider_id.clone();
|
||||
let mut status = ProviderStatus::Unavailable;
|
||||
let mut models = Vec::new();
|
||||
|
||||
match provider.health_check().await {
|
||||
Ok(health) => {
|
||||
status = health;
|
||||
if matches!(status, ProviderStatus::Available) {
|
||||
match provider.list_models().await {
|
||||
Ok(list) => {
|
||||
models = list;
|
||||
}
|
||||
Err(err) => {
|
||||
status = ProviderStatus::Unavailable;
|
||||
warn!("listing models failed for provider {}: {}", log_id, err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
warn!("health check failed for provider {}: {}", log_id, err);
|
||||
}
|
||||
}
|
||||
|
||||
(provider_id, status, models)
|
||||
});
|
||||
}
|
||||
|
||||
let mut annotated = Vec::new();
|
||||
let mut status_updates = HashMap::new();
|
||||
|
||||
while let Some((provider_id, status, models)) = tasks.next().await {
|
||||
status_updates.insert(provider_id.clone(), status);
|
||||
for model in models {
|
||||
annotated.push(AnnotatedModelInfo {
|
||||
provider_id: provider_id.clone(),
|
||||
provider_status: status,
|
||||
model,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
let mut guard = self.status_cache.write().await;
|
||||
for (provider_id, status) in status_updates {
|
||||
guard.insert(provider_id, status);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(annotated)
|
||||
}
|
||||
|
||||
/// Refresh the health of all registered providers in parallel, returning
|
||||
/// the latest status snapshot.
|
||||
pub async fn refresh_health(&self) -> HashMap<String, ProviderStatus> {
|
||||
let providers: Vec<(String, Arc<dyn ModelProvider>)> = {
|
||||
let guard = self.providers.read().await;
|
||||
guard
|
||||
.iter()
|
||||
.map(|(id, provider)| (id.clone(), Arc::clone(provider)))
|
||||
.collect()
|
||||
};
|
||||
|
||||
let mut tasks = FuturesUnordered::new();
|
||||
for (provider_id, provider) in providers {
|
||||
tasks.push(async move {
|
||||
let status = match provider.health_check().await {
|
||||
Ok(status) => status,
|
||||
Err(err) => {
|
||||
warn!("health check failed for provider {}: {}", provider_id, err);
|
||||
ProviderStatus::Unavailable
|
||||
}
|
||||
};
|
||||
(provider_id, status)
|
||||
});
|
||||
}
|
||||
|
||||
let mut updates = HashMap::new();
|
||||
while let Some((provider_id, status)) = tasks.next().await {
|
||||
updates.insert(provider_id, status);
|
||||
}
|
||||
|
||||
{
|
||||
let mut guard = self.status_cache.write().await;
|
||||
for (provider_id, status) in &updates {
|
||||
guard.insert(provider_id.clone(), *status);
|
||||
}
|
||||
}
|
||||
|
||||
updates
|
||||
}
|
||||
|
||||
/// Return the provider instance for an identifier.
|
||||
pub async fn get_provider(&self, provider_id: &str) -> Option<Arc<dyn ModelProvider>> {
|
||||
let guard = self.providers.read().await;
|
||||
guard.get(provider_id).cloned()
|
||||
}
|
||||
|
||||
/// List the registered provider identifiers.
|
||||
pub async fn provider_ids(&self) -> Vec<String> {
|
||||
let guard = self.providers.read().await;
|
||||
guard.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Retrieve the last known status for a provider.
|
||||
pub async fn provider_status(&self, provider_id: &str) -> Option<ProviderStatus> {
|
||||
let guard = self.status_cache.read().await;
|
||||
guard.get(provider_id).copied()
|
||||
}
|
||||
|
||||
/// Snapshot the currently cached statuses.
|
||||
pub async fn provider_statuses(&self) -> HashMap<String, ProviderStatus> {
|
||||
let guard = self.status_cache.read().await;
|
||||
guard.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ProviderManager {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
providers: RwLock::new(HashMap::new()),
|
||||
status_cache: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
36
crates/owlen-core/src/provider/mod.rs
Normal file
36
crates/owlen-core/src/provider/mod.rs
Normal file
@@ -0,0 +1,36 @@
|
||||
//! Unified provider abstraction layer.
|
||||
//!
|
||||
//! This module defines the async [`ModelProvider`] trait that all model
|
||||
//! backends implement, together with a small suite of shared data structures
|
||||
//! used for model discovery and streaming generation. The [`ProviderManager`]
|
||||
//! orchestrates multiple providers and coordinates their health state.
|
||||
|
||||
mod manager;
|
||||
mod types;
|
||||
|
||||
use std::pin::Pin;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::Stream;
|
||||
|
||||
pub use self::{manager::*, types::*};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
/// Convenience alias for the stream type yielded by [`ModelProvider::generate_stream`].
|
||||
pub type GenerateStream = Pin<Box<dyn Stream<Item = Result<GenerateChunk>> + Send + 'static>>;
|
||||
|
||||
#[async_trait]
|
||||
pub trait ModelProvider: Send + Sync {
|
||||
/// Returns descriptive metadata about the provider.
|
||||
fn metadata(&self) -> &ProviderMetadata;
|
||||
|
||||
/// Check the current health state for the provider.
|
||||
async fn health_check(&self) -> Result<ProviderStatus>;
|
||||
|
||||
/// List all models available through the provider.
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>>;
|
||||
|
||||
/// Acquire a streaming response for a generation request.
|
||||
async fn generate_stream(&self, request: GenerateRequest) -> Result<GenerateStream>;
|
||||
}
|
||||
124
crates/owlen-core/src/provider/types.rs
Normal file
124
crates/owlen-core/src/provider/types.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
//! Shared types used by the unified provider abstraction layer.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
/// Categorises providers so the UI can distinguish between local and hosted
|
||||
/// backends.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub enum ProviderType {
|
||||
Local,
|
||||
Cloud,
|
||||
}
|
||||
|
||||
/// Represents the current availability state for a provider.
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub enum ProviderStatus {
|
||||
Available,
|
||||
Unavailable,
|
||||
RequiresSetup,
|
||||
}
|
||||
|
||||
/// Describes core metadata for a provider implementation.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct ProviderMetadata {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub provider_type: ProviderType,
|
||||
pub requires_auth: bool,
|
||||
#[serde(default)]
|
||||
pub metadata: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
impl ProviderMetadata {
|
||||
/// Construct a new metadata instance for a provider.
|
||||
pub fn new(
|
||||
id: impl Into<String>,
|
||||
name: impl Into<String>,
|
||||
provider_type: ProviderType,
|
||||
requires_auth: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: id.into(),
|
||||
name: name.into(),
|
||||
provider_type,
|
||||
requires_auth,
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Information about a model that can be displayed to users.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct ModelInfo {
|
||||
pub name: String,
|
||||
#[serde(default)]
|
||||
pub size_bytes: Option<u64>,
|
||||
#[serde(default)]
|
||||
pub capabilities: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
pub provider: ProviderMetadata,
|
||||
#[serde(default)]
|
||||
pub metadata: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
/// Unified request for streaming text generation across providers.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct GenerateRequest {
|
||||
pub model: String,
|
||||
#[serde(default)]
|
||||
pub prompt: Option<String>,
|
||||
#[serde(default)]
|
||||
pub context: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub parameters: HashMap<String, Value>,
|
||||
#[serde(default)]
|
||||
pub metadata: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
impl GenerateRequest {
|
||||
/// Helper for building a request from the minimum required fields.
|
||||
pub fn new(model: impl Into<String>) -> Self {
|
||||
Self {
|
||||
model: model.into(),
|
||||
prompt: None,
|
||||
context: Vec::new(),
|
||||
parameters: HashMap::new(),
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Streamed chunk of generation output from a model.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct GenerateChunk {
|
||||
#[serde(default)]
|
||||
pub text: Option<String>,
|
||||
#[serde(default)]
|
||||
pub is_final: bool,
|
||||
#[serde(default)]
|
||||
pub metadata: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
impl GenerateChunk {
|
||||
/// Construct a new chunk with the provided text payload.
|
||||
pub fn from_text(text: impl Into<String>) -> Self {
|
||||
Self {
|
||||
text: Some(text.into()),
|
||||
is_final: false,
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark the chunk as the terminal item in a stream.
|
||||
pub fn final_chunk() -> Self {
|
||||
Self {
|
||||
text: None,
|
||||
is_final: true,
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,11 @@
|
||||
//! Ollama provider built on top of the `ollama-rs` crate.
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
collections::{HashMap, HashSet},
|
||||
env,
|
||||
net::{SocketAddr, TcpStream},
|
||||
pin::Pin,
|
||||
time::{Duration, SystemTime},
|
||||
sync::Arc,
|
||||
time::{Duration, Instant, SystemTime},
|
||||
};
|
||||
|
||||
use anyhow::anyhow;
|
||||
@@ -22,11 +24,17 @@ use ollama_rs::{
|
||||
};
|
||||
use reqwest::{Client, StatusCode, Url};
|
||||
use serde_json::{Map as JsonMap, Value, json};
|
||||
use tokio::{sync::RwLock, time::timeout};
|
||||
|
||||
#[cfg(test)]
|
||||
use std::sync::{Mutex, OnceLock};
|
||||
#[cfg(test)]
|
||||
use tokio_test::block_on;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
Error, Result,
|
||||
config::GeneralSettings,
|
||||
config::{GeneralSettings, OLLAMA_CLOUD_BASE_URL, OLLAMA_CLOUD_ENDPOINT_KEY, OLLAMA_MODE_KEY},
|
||||
llm::{LlmProvider, ProviderConfig},
|
||||
mcp::McpToolDescriptor,
|
||||
model::{DetailedModelInfo, ModelDetailsCache, ModelManager},
|
||||
@@ -37,9 +45,11 @@ use crate::{
|
||||
|
||||
const DEFAULT_TIMEOUT_SECS: u64 = 120;
|
||||
const DEFAULT_MODEL_CACHE_TTL_SECS: u64 = 60;
|
||||
const CLOUD_BASE_URL: &str = "https://ollama.com";
|
||||
pub(crate) const CLOUD_BASE_URL: &str = OLLAMA_CLOUD_BASE_URL;
|
||||
const LOCAL_PROBE_TIMEOUT_MS: u64 = 200;
|
||||
const LOCAL_PROBE_TARGETS: &[&str] = &["127.0.0.1:11434", "[::1]:11434"];
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
enum OllamaMode {
|
||||
Local,
|
||||
Cloud,
|
||||
@@ -54,6 +64,64 @@ impl OllamaMode {
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum ScopeAvailability {
|
||||
Unknown,
|
||||
Available,
|
||||
Unavailable,
|
||||
}
|
||||
|
||||
impl ScopeAvailability {
|
||||
fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
ScopeAvailability::Unknown => "unknown",
|
||||
ScopeAvailability::Available => "available",
|
||||
ScopeAvailability::Unavailable => "unavailable",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ScopeSnapshot {
|
||||
models: Vec<ModelInfo>,
|
||||
fetched_at: Option<Instant>,
|
||||
availability: ScopeAvailability,
|
||||
last_error: Option<String>,
|
||||
last_checked: Option<Instant>,
|
||||
last_success_at: Option<Instant>,
|
||||
}
|
||||
|
||||
impl Default for ScopeSnapshot {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
models: Vec::new(),
|
||||
fetched_at: None,
|
||||
availability: ScopeAvailability::Unknown,
|
||||
last_error: None,
|
||||
last_checked: None,
|
||||
last_success_at: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ScopeSnapshot {
|
||||
fn is_stale(&self, ttl: Duration) -> bool {
|
||||
match self.fetched_at {
|
||||
Some(ts) => ts.elapsed() >= ttl,
|
||||
None => !self.models.is_empty(),
|
||||
}
|
||||
}
|
||||
|
||||
fn last_checked_age_secs(&self) -> Option<u64> {
|
||||
self.last_checked.map(|instant| instant.elapsed().as_secs())
|
||||
}
|
||||
|
||||
fn last_success_age_secs(&self) -> Option<u64> {
|
||||
self.last_success_at
|
||||
.map(|instant| instant.elapsed().as_secs())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct OllamaOptions {
|
||||
mode: OllamaMode,
|
||||
@@ -61,6 +129,7 @@ struct OllamaOptions {
|
||||
request_timeout: Duration,
|
||||
model_cache_ttl: Duration,
|
||||
api_key: Option<String>,
|
||||
cloud_endpoint: Option<String>,
|
||||
}
|
||||
|
||||
impl OllamaOptions {
|
||||
@@ -71,6 +140,7 @@ impl OllamaOptions {
|
||||
request_timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
|
||||
model_cache_ttl: Duration::from_secs(DEFAULT_MODEL_CACHE_TTL_SECS),
|
||||
api_key: None,
|
||||
cloud_endpoint: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -87,8 +157,78 @@ pub struct OllamaProvider {
|
||||
client: Ollama,
|
||||
http_client: Client,
|
||||
base_url: String,
|
||||
request_timeout: Duration,
|
||||
api_key: Option<String>,
|
||||
cloud_endpoint: Option<String>,
|
||||
model_manager: ModelManager,
|
||||
model_details_cache: ModelDetailsCache,
|
||||
model_cache_ttl: Duration,
|
||||
scope_cache: Arc<RwLock<HashMap<OllamaMode, ScopeSnapshot>>>,
|
||||
}
|
||||
|
||||
fn configured_mode_from_extra(config: &ProviderConfig) -> Option<OllamaMode> {
|
||||
config
|
||||
.extra
|
||||
.get(OLLAMA_MODE_KEY)
|
||||
.and_then(|value| value.as_str())
|
||||
.and_then(|value| match value.trim().to_ascii_lowercase().as_str() {
|
||||
"local" => Some(OllamaMode::Local),
|
||||
"cloud" => Some(OllamaMode::Cloud),
|
||||
_ => None,
|
||||
})
|
||||
}
|
||||
|
||||
fn is_explicit_local_base(base_url: Option<&str>) -> bool {
|
||||
base_url
|
||||
.and_then(|raw| Url::parse(raw).ok())
|
||||
.and_then(|parsed| parsed.host_str().map(|host| host.to_ascii_lowercase()))
|
||||
.map(|host| host == "localhost" || host == "127.0.0.1" || host == "::1")
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn is_explicit_cloud_base(base_url: Option<&str>) -> bool {
|
||||
base_url
|
||||
.map(|raw| {
|
||||
let trimmed = raw.trim_end_matches('/');
|
||||
trimmed == CLOUD_BASE_URL || trimmed.starts_with("https://ollama.com/")
|
||||
})
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
static PROBE_OVERRIDE: OnceLock<Mutex<Option<bool>>> = OnceLock::new();
|
||||
|
||||
#[cfg(test)]
|
||||
fn set_probe_override(value: Option<bool>) {
|
||||
let guard = PROBE_OVERRIDE.get_or_init(|| Mutex::new(None));
|
||||
*guard.lock().expect("probe override mutex poisoned") = value;
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn probe_override_value() -> Option<bool> {
|
||||
PROBE_OVERRIDE
|
||||
.get_or_init(|| Mutex::new(None))
|
||||
.lock()
|
||||
.expect("probe override mutex poisoned")
|
||||
.to_owned()
|
||||
}
|
||||
|
||||
fn probe_default_local_daemon(timeout: Duration) -> bool {
|
||||
#[cfg(test)]
|
||||
{
|
||||
if let Some(value) = probe_override_value() {
|
||||
return value;
|
||||
}
|
||||
}
|
||||
|
||||
for target in LOCAL_PROBE_TARGETS {
|
||||
if let Ok(address) = target.parse::<SocketAddr>() {
|
||||
if TcpStream::connect_timeout(&address, timeout).is_ok() {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
impl OllamaProvider {
|
||||
@@ -105,23 +245,64 @@ impl OllamaProvider {
|
||||
let mut api_key = resolve_api_key(config.api_key.clone())
|
||||
.or_else(|| env_var_non_empty("OLLAMA_API_KEY"))
|
||||
.or_else(|| env_var_non_empty("OLLAMA_CLOUD_API_KEY"));
|
||||
let configured_mode = configured_mode_from_extra(config);
|
||||
let configured_mode_label = config
|
||||
.extra
|
||||
.get(OLLAMA_MODE_KEY)
|
||||
.and_then(|value| value.as_str())
|
||||
.unwrap_or("auto");
|
||||
let base_url = config.base_url.as_deref();
|
||||
let base_is_local = is_explicit_local_base(base_url);
|
||||
let base_is_cloud = is_explicit_cloud_base(base_url);
|
||||
let base_is_other = base_url.is_some() && !base_is_local && !base_is_cloud;
|
||||
|
||||
let mode = if api_key.is_some() {
|
||||
OllamaMode::Cloud
|
||||
} else {
|
||||
OllamaMode::Local
|
||||
let mut local_probe_result = None;
|
||||
let cloud_endpoint = config
|
||||
.extra
|
||||
.get(OLLAMA_CLOUD_ENDPOINT_KEY)
|
||||
.and_then(Value::as_str)
|
||||
.map(normalize_cloud_endpoint)
|
||||
.transpose()
|
||||
.map_err(Error::Config)?;
|
||||
|
||||
let mode = match configured_mode {
|
||||
Some(mode) => mode,
|
||||
None => {
|
||||
if base_is_local || base_is_other {
|
||||
OllamaMode::Local
|
||||
} else if base_is_cloud && api_key.is_some() {
|
||||
OllamaMode::Cloud
|
||||
} else {
|
||||
let probe =
|
||||
probe_default_local_daemon(Duration::from_millis(LOCAL_PROBE_TIMEOUT_MS));
|
||||
local_probe_result = Some(probe);
|
||||
if probe {
|
||||
OllamaMode::Local
|
||||
} else if api_key.is_some() {
|
||||
OllamaMode::Cloud
|
||||
} else {
|
||||
OllamaMode::Local
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let base_candidate = if mode == OllamaMode::Cloud {
|
||||
Some(CLOUD_BASE_URL)
|
||||
} else {
|
||||
config.base_url.as_deref()
|
||||
let base_candidate = match mode {
|
||||
OllamaMode::Local => base_url,
|
||||
OllamaMode::Cloud => {
|
||||
if base_is_cloud {
|
||||
base_url
|
||||
} else {
|
||||
Some(CLOUD_BASE_URL)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let normalized_base_url =
|
||||
normalize_base_url(base_candidate, mode).map_err(Error::Config)?;
|
||||
|
||||
let mut options = OllamaOptions::new(mode, normalized_base_url);
|
||||
let mut options = OllamaOptions::new(mode, normalized_base_url.clone());
|
||||
options.cloud_endpoint = cloud_endpoint.clone();
|
||||
|
||||
if let Some(timeout) = config
|
||||
.extra
|
||||
@@ -145,6 +326,23 @@ impl OllamaProvider {
|
||||
options = options.with_general(general);
|
||||
}
|
||||
|
||||
debug!(
|
||||
"Resolved Ollama provider: mode={:?}, base_url={}, configured_mode={}, api_key_present={}, local_probe={}",
|
||||
mode,
|
||||
normalized_base_url,
|
||||
configured_mode_label,
|
||||
if options.api_key.is_some() {
|
||||
"yes"
|
||||
} else {
|
||||
"no"
|
||||
},
|
||||
match local_probe_result {
|
||||
Some(true) => "success",
|
||||
Some(false) => "failed",
|
||||
None => "skipped",
|
||||
}
|
||||
);
|
||||
|
||||
Self::with_options(options)
|
||||
}
|
||||
|
||||
@@ -155,44 +353,32 @@ impl OllamaProvider {
|
||||
request_timeout,
|
||||
model_cache_ttl,
|
||||
api_key,
|
||||
cloud_endpoint,
|
||||
} = options;
|
||||
|
||||
let url = Url::parse(&base_url)
|
||||
.map_err(|err| Error::Config(format!("Invalid Ollama base URL '{base_url}': {err}")))?;
|
||||
let api_key_ref = api_key.as_deref();
|
||||
let (ollama_client, http_client) =
|
||||
build_client_for_base(&base_url, request_timeout, api_key_ref)?;
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
if let Some(ref key) = api_key {
|
||||
let value = HeaderValue::from_str(&format!("Bearer {key}")).map_err(|_| {
|
||||
Error::Config("OLLAMA API key contains invalid characters".to_string())
|
||||
})?;
|
||||
headers.insert(AUTHORIZATION, value);
|
||||
}
|
||||
|
||||
let mut client_builder = Client::builder().timeout(request_timeout);
|
||||
if !headers.is_empty() {
|
||||
client_builder = client_builder.default_headers(headers.clone());
|
||||
}
|
||||
|
||||
let http_client = client_builder
|
||||
.build()
|
||||
.map_err(|err| Error::Config(format!("Failed to build HTTP client: {err}")))?;
|
||||
|
||||
let port = url.port_or_known_default().ok_or_else(|| {
|
||||
Error::Config(format!("Unable to determine port for Ollama URL '{}'", url))
|
||||
})?;
|
||||
|
||||
let mut ollama_client = Ollama::new_with_client(url.clone(), port, http_client.clone());
|
||||
if !headers.is_empty() {
|
||||
ollama_client.set_headers(Some(headers.clone()));
|
||||
}
|
||||
let scope_cache = {
|
||||
let mut initial = HashMap::new();
|
||||
initial.insert(OllamaMode::Local, ScopeSnapshot::default());
|
||||
initial.insert(OllamaMode::Cloud, ScopeSnapshot::default());
|
||||
Arc::new(RwLock::new(initial))
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
mode,
|
||||
client: ollama_client,
|
||||
http_client,
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
request_timeout,
|
||||
api_key,
|
||||
cloud_endpoint,
|
||||
model_manager: ModelManager::new(model_cache_ttl),
|
||||
model_details_cache: ModelDetailsCache::new(model_cache_ttl),
|
||||
model_cache_ttl,
|
||||
scope_cache,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -200,6 +386,167 @@ impl OllamaProvider {
|
||||
build_api_endpoint(&self.base_url, endpoint)
|
||||
}
|
||||
|
||||
fn local_base_url() -> &'static str {
|
||||
OllamaMode::Local.default_base_url()
|
||||
}
|
||||
|
||||
fn scope_key(scope: OllamaMode) -> &'static str {
|
||||
match scope {
|
||||
OllamaMode::Local => "local",
|
||||
OllamaMode::Cloud => "cloud",
|
||||
}
|
||||
}
|
||||
|
||||
fn build_local_client(&self) -> Result<Option<Ollama>> {
|
||||
if matches!(self.mode, OllamaMode::Local) {
|
||||
return Ok(Some(self.client.clone()));
|
||||
}
|
||||
|
||||
let (client, _) =
|
||||
build_client_for_base(Self::local_base_url(), self.request_timeout, None)?;
|
||||
Ok(Some(client))
|
||||
}
|
||||
|
||||
fn build_cloud_client(&self) -> Result<Option<Ollama>> {
|
||||
if matches!(self.mode, OllamaMode::Cloud) {
|
||||
return Ok(Some(self.client.clone()));
|
||||
}
|
||||
|
||||
let api_key = match self.api_key.as_deref() {
|
||||
Some(key) if !key.trim().is_empty() => key,
|
||||
_ => return Ok(None),
|
||||
};
|
||||
|
||||
let endpoint = self.cloud_endpoint.as_deref().unwrap_or(CLOUD_BASE_URL);
|
||||
|
||||
let (client, _) = build_client_for_base(endpoint, self.request_timeout, Some(api_key))?;
|
||||
Ok(Some(client))
|
||||
}
|
||||
|
||||
async fn cached_scope_models(&self, scope: OllamaMode) -> Option<Vec<ModelInfo>> {
|
||||
let cache = self.scope_cache.read().await;
|
||||
cache.get(&scope).and_then(|entry| {
|
||||
if entry.availability == ScopeAvailability::Unknown {
|
||||
return None;
|
||||
}
|
||||
|
||||
if entry.models.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(ts) = entry.fetched_at {
|
||||
if ts.elapsed() < self.model_cache_ttl {
|
||||
return Some(entry.models.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to last good models even if stale; UI will mark as degraded
|
||||
Some(entry.models.clone())
|
||||
})
|
||||
}
|
||||
|
||||
async fn update_scope_success(&self, scope: OllamaMode, models: &[ModelInfo]) {
|
||||
let mut cache = self.scope_cache.write().await;
|
||||
let entry = cache.entry(scope).or_default();
|
||||
let now = Instant::now();
|
||||
entry.models = models.to_vec();
|
||||
entry.fetched_at = Some(now);
|
||||
entry.last_checked = Some(now);
|
||||
entry.last_success_at = Some(now);
|
||||
entry.availability = ScopeAvailability::Available;
|
||||
entry.last_error = None;
|
||||
}
|
||||
|
||||
async fn mark_scope_failure(&self, scope: OllamaMode, message: String) {
|
||||
let mut cache = self.scope_cache.write().await;
|
||||
let entry = cache.entry(scope).or_default();
|
||||
entry.availability = ScopeAvailability::Unavailable;
|
||||
entry.last_error = Some(message);
|
||||
entry.last_checked = Some(Instant::now());
|
||||
}
|
||||
|
||||
async fn annotate_scope_status(&self, models: &mut [ModelInfo]) {
|
||||
if models.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let cache = self.scope_cache.read().await;
|
||||
for (scope, snapshot) in cache.iter() {
|
||||
if snapshot.availability == ScopeAvailability::Unknown {
|
||||
continue;
|
||||
}
|
||||
let scope_key = Self::scope_key(*scope);
|
||||
let capability = format!(
|
||||
"scope-status:{}:{}",
|
||||
scope_key,
|
||||
snapshot.availability.as_str()
|
||||
);
|
||||
|
||||
for model in models.iter_mut() {
|
||||
if !model.capabilities.iter().any(|cap| cap == &capability) {
|
||||
model.capabilities.push(capability.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let stale = snapshot.is_stale(self.model_cache_ttl);
|
||||
let stale_capability = format!(
|
||||
"scope-status-stale:{}:{}",
|
||||
scope_key,
|
||||
if stale { "1" } else { "0" }
|
||||
);
|
||||
for model in models.iter_mut() {
|
||||
if !model
|
||||
.capabilities
|
||||
.iter()
|
||||
.any(|cap| cap == &stale_capability)
|
||||
{
|
||||
model.capabilities.push(stale_capability.clone());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(age) = snapshot.last_checked_age_secs() {
|
||||
let age_capability = format!("scope-status-age:{}:{}", scope_key, age);
|
||||
for model in models.iter_mut() {
|
||||
if !model.capabilities.iter().any(|cap| cap == &age_capability) {
|
||||
model.capabilities.push(age_capability.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(success_age) = snapshot.last_success_age_secs() {
|
||||
let success_capability =
|
||||
format!("scope-status-success-age:{}:{}", scope_key, success_age);
|
||||
for model in models.iter_mut() {
|
||||
if !model
|
||||
.capabilities
|
||||
.iter()
|
||||
.any(|cap| cap == &success_capability)
|
||||
{
|
||||
model.capabilities.push(success_capability.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(raw_reason) = snapshot.last_error.as_ref() {
|
||||
let cleaned = raw_reason.replace('\n', " ").trim().to_string();
|
||||
if !cleaned.is_empty() {
|
||||
let truncated: String = cleaned.chars().take(160).collect();
|
||||
let message_capability =
|
||||
format!("scope-status-message:{}:{}", scope_key, truncated);
|
||||
for model in models.iter_mut() {
|
||||
if !model
|
||||
.capabilities
|
||||
.iter()
|
||||
.any(|cap| cap == &message_capability)
|
||||
{
|
||||
model.capabilities.push(message_capability.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempt to resolve detailed model information for the given model, using the local cache when possible.
|
||||
pub async fn get_model_info(&self, model_name: &str) -> Result<DetailedModelInfo> {
|
||||
if let Some(info) = self.model_details_cache.get(model_name).await {
|
||||
@@ -312,15 +659,92 @@ impl OllamaProvider {
|
||||
}
|
||||
|
||||
async fn fetch_models(&self) -> Result<Vec<ModelInfo>> {
|
||||
let models = self
|
||||
.client
|
||||
.list_local_models()
|
||||
.await
|
||||
.map_err(|err| self.map_ollama_error("list models", err, None))?;
|
||||
let mut combined = Vec::new();
|
||||
let mut seen: HashSet<String> = HashSet::new();
|
||||
let mut errors: Vec<Error> = Vec::new();
|
||||
|
||||
if let Some(local_client) = self.build_local_client()? {
|
||||
match self
|
||||
.fetch_models_for_scope(OllamaMode::Local, local_client.clone())
|
||||
.await
|
||||
{
|
||||
Ok(models) => {
|
||||
for model in models {
|
||||
let key = format!("local::{}", model.id);
|
||||
if seen.insert(key) {
|
||||
combined.push(model);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => errors.push(err),
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(cloud_client) = self.build_cloud_client()? {
|
||||
match self
|
||||
.fetch_models_for_scope(OllamaMode::Cloud, cloud_client.clone())
|
||||
.await
|
||||
{
|
||||
Ok(models) => {
|
||||
for model in models {
|
||||
let key = format!("cloud::{}", model.id);
|
||||
if seen.insert(key) {
|
||||
combined.push(model);
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => errors.push(err),
|
||||
}
|
||||
}
|
||||
|
||||
if combined.is_empty() {
|
||||
if let Some(err) = errors.pop() {
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
|
||||
self.annotate_scope_status(&mut combined).await;
|
||||
combined.sort_by(|a, b| a.name.to_lowercase().cmp(&b.name.to_lowercase()));
|
||||
Ok(combined)
|
||||
}
|
||||
|
||||
async fn fetch_models_for_scope(
|
||||
&self,
|
||||
scope: OllamaMode,
|
||||
client: Ollama,
|
||||
) -> Result<Vec<ModelInfo>> {
|
||||
let list_result = if matches!(scope, OllamaMode::Local) {
|
||||
match timeout(
|
||||
Duration::from_millis(LOCAL_PROBE_TIMEOUT_MS),
|
||||
client.list_local_models(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => result.map_err(|err| self.map_ollama_error("list models", err, None)),
|
||||
Err(_) => Err(Error::Timeout(
|
||||
"Timed out while contacting the local Ollama daemon".to_string(),
|
||||
)),
|
||||
}
|
||||
} else {
|
||||
client
|
||||
.list_local_models()
|
||||
.await
|
||||
.map_err(|err| self.map_ollama_error("list models", err, None))
|
||||
};
|
||||
|
||||
let models = match list_result {
|
||||
Ok(models) => models,
|
||||
Err(err) => {
|
||||
let message = err.to_string();
|
||||
self.mark_scope_failure(scope, message).await;
|
||||
if let Some(cached) = self.cached_scope_models(scope).await {
|
||||
return Ok(cached);
|
||||
}
|
||||
return Err(err);
|
||||
}
|
||||
};
|
||||
|
||||
let client = self.client.clone();
|
||||
let cache = self.model_details_cache.clone();
|
||||
let mode = self.mode;
|
||||
let fetched = join_all(models.into_iter().map(|local| {
|
||||
let client = client.clone();
|
||||
let cache = cache.clone();
|
||||
@@ -329,7 +753,7 @@ impl OllamaProvider {
|
||||
let detail = match client.show_model_info(name.clone()).await {
|
||||
Ok(info) => {
|
||||
let detailed = OllamaProvider::convert_detailed_model_info(
|
||||
mode,
|
||||
scope,
|
||||
&name,
|
||||
Some(&local),
|
||||
&info,
|
||||
@@ -347,10 +771,13 @@ impl OllamaProvider {
|
||||
}))
|
||||
.await;
|
||||
|
||||
Ok(fetched
|
||||
let converted: Vec<ModelInfo> = fetched
|
||||
.into_iter()
|
||||
.map(|(local, detail)| self.convert_model(local, detail))
|
||||
.collect())
|
||||
.map(|(local, detail)| self.convert_model(scope, local, detail))
|
||||
.collect();
|
||||
|
||||
self.update_scope_success(scope, &converted).await;
|
||||
Ok(converted)
|
||||
}
|
||||
|
||||
fn convert_detailed_model_info(
|
||||
@@ -378,10 +805,8 @@ impl OllamaProvider {
|
||||
let family = pick_first_string(map, &["family", "model_family"]);
|
||||
let mut families = pick_string_list(map, &["families", "model_families"]);
|
||||
|
||||
if families.is_empty()
|
||||
&& let Some(single) = family.clone()
|
||||
{
|
||||
families.push(single);
|
||||
if families.is_empty() {
|
||||
families.extend(family.clone());
|
||||
}
|
||||
|
||||
let system = pick_first_string(map, &["system"]);
|
||||
@@ -432,8 +857,13 @@ impl OllamaProvider {
|
||||
info.with_normalised_strings()
|
||||
}
|
||||
|
||||
fn convert_model(&self, model: LocalModel, detail: Option<OllamaModelInfo>) -> ModelInfo {
|
||||
let scope = match self.mode {
|
||||
fn convert_model(
|
||||
&self,
|
||||
scope: OllamaMode,
|
||||
model: LocalModel,
|
||||
detail: Option<OllamaModelInfo>,
|
||||
) -> ModelInfo {
|
||||
let scope_tag = match scope {
|
||||
OllamaMode::Local => "local",
|
||||
OllamaMode::Cloud => "cloud",
|
||||
};
|
||||
@@ -455,7 +885,9 @@ impl OllamaProvider {
|
||||
push_capability(&mut capabilities, &heuristic);
|
||||
}
|
||||
|
||||
let description = build_model_description(scope, detail.as_ref());
|
||||
push_capability(&mut capabilities, &format!("scope:{scope_tag}"));
|
||||
|
||||
let description = build_model_description(scope_tag, detail.as_ref());
|
||||
|
||||
ModelInfo {
|
||||
id: name.clone(),
|
||||
@@ -1006,6 +1438,10 @@ fn normalize_base_url(
|
||||
Ok(url.to_string().trim_end_matches('/').to_string())
|
||||
}
|
||||
|
||||
fn normalize_cloud_endpoint(input: &str) -> std::result::Result<String, String> {
|
||||
normalize_base_url(Some(input), OllamaMode::Cloud)
|
||||
}
|
||||
|
||||
fn build_api_endpoint(base_url: &str, endpoint: &str) -> String {
|
||||
let trimmed_base = base_url.trim_end_matches('/');
|
||||
let trimmed_endpoint = endpoint.trim_start_matches('/');
|
||||
@@ -1017,9 +1453,48 @@ fn build_api_endpoint(base_url: &str, endpoint: &str) -> String {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_client_for_base(
|
||||
base_url: &str,
|
||||
timeout: Duration,
|
||||
api_key: Option<&str>,
|
||||
) -> Result<(Ollama, Client)> {
|
||||
let url = Url::parse(base_url)
|
||||
.map_err(|err| Error::Config(format!("Invalid Ollama base URL '{base_url}': {err}")))?;
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
if let Some(key) = api_key {
|
||||
let value = HeaderValue::from_str(&format!("Bearer {key}"))
|
||||
.map_err(|_| Error::Config("OLLAMA API key contains invalid characters".to_string()))?;
|
||||
headers.insert(AUTHORIZATION, value);
|
||||
}
|
||||
|
||||
let mut client_builder = Client::builder().timeout(timeout);
|
||||
if !headers.is_empty() {
|
||||
client_builder = client_builder.default_headers(headers.clone());
|
||||
}
|
||||
|
||||
let http_client = client_builder.build().map_err(|err| {
|
||||
Error::Config(format!(
|
||||
"Failed to build HTTP client for '{base_url}': {err}"
|
||||
))
|
||||
})?;
|
||||
|
||||
let port = url.port_or_known_default().ok_or_else(|| {
|
||||
Error::Config(format!("Unable to determine port for Ollama URL '{}'", url))
|
||||
})?;
|
||||
|
||||
let mut ollama_client = Ollama::new_with_client(url.clone(), port, http_client.clone());
|
||||
if !headers.is_empty() {
|
||||
ollama_client.set_headers(Some(headers));
|
||||
}
|
||||
|
||||
Ok((ollama_client, http_client))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[test]
|
||||
fn resolve_api_key_prefers_literal_value() {
|
||||
@@ -1055,6 +1530,66 @@ mod tests {
|
||||
assert!(err.contains("https"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn explicit_local_mode_overrides_api_key() {
|
||||
let mut config = ProviderConfig {
|
||||
enabled: true,
|
||||
provider_type: "ollama".to_string(),
|
||||
base_url: Some("http://localhost:11434".to_string()),
|
||||
api_key: Some("secret-key".to_string()),
|
||||
api_key_env: None,
|
||||
extra: HashMap::new(),
|
||||
};
|
||||
config.extra.insert(
|
||||
OLLAMA_MODE_KEY.to_string(),
|
||||
Value::String("local".to_string()),
|
||||
);
|
||||
|
||||
let provider = OllamaProvider::from_config(&config, None).expect("provider constructed");
|
||||
|
||||
assert_eq!(provider.mode, OllamaMode::Local);
|
||||
assert_eq!(provider.base_url, "http://localhost:11434");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auto_mode_prefers_explicit_local_base() {
|
||||
let config = ProviderConfig {
|
||||
enabled: true,
|
||||
provider_type: "ollama".to_string(),
|
||||
base_url: Some("http://localhost:11434".to_string()),
|
||||
api_key: Some("secret-key".to_string()),
|
||||
api_key_env: None,
|
||||
extra: HashMap::new(),
|
||||
};
|
||||
// simulate missing explicit mode; defaults to auto
|
||||
|
||||
let provider = OllamaProvider::from_config(&config, None).expect("provider constructed");
|
||||
|
||||
assert_eq!(provider.mode, OllamaMode::Local);
|
||||
assert_eq!(provider.base_url, "http://localhost:11434");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auto_mode_with_api_key_and_no_local_probe_switches_to_cloud() {
|
||||
let mut config = ProviderConfig {
|
||||
enabled: true,
|
||||
provider_type: "ollama".to_string(),
|
||||
base_url: None,
|
||||
api_key: Some("secret-key".to_string()),
|
||||
api_key_env: None,
|
||||
extra: HashMap::new(),
|
||||
};
|
||||
config.extra.insert(
|
||||
OLLAMA_MODE_KEY.to_string(),
|
||||
Value::String("auto".to_string()),
|
||||
);
|
||||
|
||||
let provider = OllamaProvider::from_config(&config, None).expect("provider constructed");
|
||||
|
||||
assert_eq!(provider.mode, OllamaMode::Cloud);
|
||||
assert_eq!(provider.base_url, CLOUD_BASE_URL);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_model_options_merges_parameters() {
|
||||
let mut parameters = ChatParameters::default();
|
||||
@@ -1093,3 +1628,127 @@ mod tests {
|
||||
assert!(caps.iter().any(|cap| cap == "vision"));
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
struct ProbeOverrideGuard;
|
||||
|
||||
#[cfg(test)]
|
||||
impl ProbeOverrideGuard {
|
||||
fn set(value: Option<bool>) -> Self {
|
||||
set_probe_override(value);
|
||||
ProbeOverrideGuard
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
impl Drop for ProbeOverrideGuard {
|
||||
fn drop(&mut self) {
|
||||
set_probe_override(None);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auto_mode_with_api_key_and_successful_probe_prefers_local() {
|
||||
let _guard = ProbeOverrideGuard::set(Some(true));
|
||||
|
||||
let mut config = ProviderConfig {
|
||||
enabled: true,
|
||||
provider_type: "ollama".to_string(),
|
||||
base_url: None,
|
||||
api_key: Some("secret-key".to_string()),
|
||||
api_key_env: None,
|
||||
extra: HashMap::new(),
|
||||
};
|
||||
config.extra.insert(
|
||||
OLLAMA_MODE_KEY.to_string(),
|
||||
Value::String("auto".to_string()),
|
||||
);
|
||||
|
||||
assert!(probe_default_local_daemon(Duration::from_millis(1)));
|
||||
|
||||
let provider = OllamaProvider::from_config(&config, None).expect("provider constructed");
|
||||
|
||||
assert_eq!(provider.mode, OllamaMode::Local);
|
||||
assert_eq!(provider.base_url, "http://localhost:11434");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn auto_mode_with_api_key_and_failed_probe_prefers_cloud() {
|
||||
let _guard = ProbeOverrideGuard::set(Some(false));
|
||||
|
||||
let mut config = ProviderConfig {
|
||||
enabled: true,
|
||||
provider_type: "ollama".to_string(),
|
||||
base_url: None,
|
||||
api_key: Some("secret-key".to_string()),
|
||||
api_key_env: None,
|
||||
extra: HashMap::new(),
|
||||
};
|
||||
config.extra.insert(
|
||||
OLLAMA_MODE_KEY.to_string(),
|
||||
Value::String("auto".to_string()),
|
||||
);
|
||||
|
||||
let provider = OllamaProvider::from_config(&config, None).expect("provider constructed");
|
||||
|
||||
assert_eq!(provider.mode, OllamaMode::Cloud);
|
||||
assert_eq!(provider.base_url, CLOUD_BASE_URL);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn annotate_scope_status_adds_capabilities_for_unavailable_scopes() {
|
||||
let config = ProviderConfig {
|
||||
enabled: true,
|
||||
provider_type: "ollama".to_string(),
|
||||
base_url: Some("http://localhost:11434".to_string()),
|
||||
api_key: None,
|
||||
api_key_env: None,
|
||||
extra: HashMap::new(),
|
||||
};
|
||||
|
||||
let provider = OllamaProvider::from_config(&config, None).expect("provider constructed");
|
||||
|
||||
let mut models = vec![ModelInfo {
|
||||
id: "llama3".to_string(),
|
||||
name: "Llama 3".to_string(),
|
||||
description: None,
|
||||
provider: "ollama".to_string(),
|
||||
context_window: None,
|
||||
capabilities: vec!["scope:local".to_string()],
|
||||
supports_tools: false,
|
||||
}];
|
||||
|
||||
block_on(async {
|
||||
{
|
||||
let mut cache = provider.scope_cache.write().await;
|
||||
let entry = cache.entry(OllamaMode::Cloud).or_default();
|
||||
entry.availability = ScopeAvailability::Unavailable;
|
||||
entry.last_error = Some("Cloud endpoint unreachable".to_string());
|
||||
entry.last_checked = Some(Instant::now());
|
||||
}
|
||||
|
||||
provider.annotate_scope_status(&mut models).await;
|
||||
});
|
||||
|
||||
let capabilities = &models[0].capabilities;
|
||||
assert!(
|
||||
capabilities
|
||||
.iter()
|
||||
.any(|cap| cap == "scope-status:cloud:unavailable")
|
||||
);
|
||||
assert!(
|
||||
capabilities
|
||||
.iter()
|
||||
.any(|cap| cap.starts_with("scope-status-message:cloud:"))
|
||||
);
|
||||
assert!(
|
||||
capabilities
|
||||
.iter()
|
||||
.any(|cap| cap.starts_with("scope-status-age:cloud:"))
|
||||
);
|
||||
assert!(
|
||||
capabilities
|
||||
.iter()
|
||||
.any(|cap| cap == "scope-status-stale:cloud:0")
|
||||
);
|
||||
}
|
||||
|
||||
@@ -71,16 +71,19 @@ impl Router {
|
||||
fn find_provider_for_model(&self, model: &str) -> Result<Arc<dyn Provider>> {
|
||||
// Check routing rules first
|
||||
for rule in &self.routing_rules {
|
||||
if self.matches_pattern(&rule.model_pattern, model)
|
||||
&& let Some(provider) = self.registry.get(&rule.provider)
|
||||
{
|
||||
if !self.matches_pattern(&rule.model_pattern, model) {
|
||||
continue;
|
||||
}
|
||||
if let Some(provider) = self.registry.get(&rule.provider) {
|
||||
return Ok(provider);
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to default provider
|
||||
if let Some(default) = &self.default_provider
|
||||
&& let Some(provider) = self.registry.get(default)
|
||||
if let Some(provider) = self
|
||||
.default_provider
|
||||
.as_ref()
|
||||
.and_then(|default| self.registry.get(default))
|
||||
{
|
||||
return Ok(provider);
|
||||
}
|
||||
|
||||
@@ -185,14 +185,20 @@ impl SandboxedProcess {
|
||||
if let Ok(output) = output {
|
||||
let version_str = String::from_utf8_lossy(&output.stdout);
|
||||
// Parse version like "bubblewrap 0.11.0" or "0.11.0"
|
||||
if let Some(version_part) = version_str.split_whitespace().last()
|
||||
&& let Some((major, rest)) = version_part.split_once('.')
|
||||
&& let Some((minor, _patch)) = rest.split_once('.')
|
||||
&& let (Ok(maj), Ok(min)) = (major.parse::<u32>(), minor.parse::<u32>())
|
||||
{
|
||||
// --rlimit-as was added in 0.12.0
|
||||
return maj > 0 || (maj == 0 && min >= 12);
|
||||
}
|
||||
return version_str
|
||||
.split_whitespace()
|
||||
.last()
|
||||
.and_then(|part| {
|
||||
part.split_once('.').and_then(|(major, rest)| {
|
||||
rest.split_once('.').and_then(|(minor, _)| {
|
||||
let maj = major.parse::<u32>().ok()?;
|
||||
let min = minor.parse::<u32>().ok()?;
|
||||
Some((maj, min))
|
||||
})
|
||||
})
|
||||
})
|
||||
.map(|(maj, min)| maj > 0 || (maj == 0 && min >= 12))
|
||||
.unwrap_or(false);
|
||||
}
|
||||
|
||||
// If we can't determine the version, assume it doesn't support it (safer default)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use crate::config::Config;
|
||||
use crate::consent::ConsentManager;
|
||||
use crate::config::{Config, McpResourceConfig, McpServerConfig};
|
||||
use crate::consent::{ConsentManager, ConsentScope};
|
||||
use crate::conversation::ConversationManager;
|
||||
use crate::credentials::CredentialManager;
|
||||
use crate::encryption::{self, VaultHandle};
|
||||
@@ -9,8 +9,10 @@ use crate::mcp::McpToolCall;
|
||||
use crate::mcp::client::McpClient;
|
||||
use crate::mcp::factory::McpClientFactory;
|
||||
use crate::mcp::permission::PermissionLayer;
|
||||
use crate::mcp::remote_client::{McpRuntimeSecrets, RemoteMcpClient};
|
||||
use crate::mode::Mode;
|
||||
use crate::model::{DetailedModelInfo, ModelManager};
|
||||
use crate::oauth::{DeviceAuthorization, DevicePollState, OAuthClient};
|
||||
use crate::providers::OllamaProvider;
|
||||
use crate::storage::{SessionMeta, StorageManager};
|
||||
use crate::types::{
|
||||
@@ -24,12 +26,15 @@ use crate::{
|
||||
ToolRegistry, WebScrapeTool, WebSearchDetailedTool, WebSearchTool,
|
||||
};
|
||||
use crate::{Error, Result};
|
||||
use chrono::Utc;
|
||||
use log::warn;
|
||||
use serde_json::Value;
|
||||
use serde_json::{Value, json};
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use tokio::sync::Mutex as TokioMutex;
|
||||
use tokio::sync::mpsc::UnboundedSender;
|
||||
use uuid::Uuid;
|
||||
|
||||
pub enum SessionOutcome {
|
||||
@@ -40,6 +45,36 @@ pub enum SessionOutcome {
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ControllerEvent {
|
||||
ToolRequested {
|
||||
request_id: Uuid,
|
||||
message_id: Uuid,
|
||||
tool_name: String,
|
||||
data_types: Vec<String>,
|
||||
endpoints: Vec<String>,
|
||||
tool_calls: Vec<ToolCall>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct PendingToolRequest {
|
||||
message_id: Uuid,
|
||||
tool_name: String,
|
||||
data_types: Vec<String>,
|
||||
endpoints: Vec<String>,
|
||||
tool_calls: Vec<ToolCall>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolConsentResolution {
|
||||
pub request_id: Uuid,
|
||||
pub message_id: Uuid,
|
||||
pub tool_name: String,
|
||||
pub scope: ConsentScope,
|
||||
pub tool_calls: Vec<ToolCall>,
|
||||
}
|
||||
|
||||
fn extract_resource_content(value: &Value) -> Option<String> {
|
||||
match value {
|
||||
Value::Null => Some(String::new()),
|
||||
@@ -49,8 +84,8 @@ fn extract_resource_content(value: &Value) -> Option<String> {
|
||||
Value::Array(items) => {
|
||||
let mut segments = Vec::new();
|
||||
for item in items {
|
||||
if let Some(segment) = extract_resource_content(item)
|
||||
&& !segment.is_empty()
|
||||
if let Some(segment) =
|
||||
extract_resource_content(item).filter(|segment| !segment.is_empty())
|
||||
{
|
||||
segments.push(segment);
|
||||
}
|
||||
@@ -65,17 +100,19 @@ fn extract_resource_content(value: &Value) -> Option<String> {
|
||||
const PREFERRED_FIELDS: [&str; 6] =
|
||||
["content", "contents", "text", "value", "body", "data"];
|
||||
for key in PREFERRED_FIELDS.iter() {
|
||||
if let Some(inner) = map.get(*key)
|
||||
&& let Some(text) = extract_resource_content(inner)
|
||||
&& !text.is_empty()
|
||||
if let Some(text) = map
|
||||
.get(*key)
|
||||
.and_then(extract_resource_content)
|
||||
.filter(|text| !text.is_empty())
|
||||
{
|
||||
return Some(text);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(inner) = map.get("chunks")
|
||||
&& let Some(text) = extract_resource_content(inner)
|
||||
&& !text.is_empty()
|
||||
if let Some(text) = map
|
||||
.get("chunks")
|
||||
.and_then(extract_resource_content)
|
||||
.filter(|text| !text.is_empty())
|
||||
{
|
||||
return Some(text);
|
||||
}
|
||||
@@ -96,6 +133,7 @@ pub struct SessionController {
|
||||
tool_registry: Arc<ToolRegistry>,
|
||||
schema_validator: Arc<SchemaValidator>,
|
||||
mcp_client: Arc<dyn McpClient>,
|
||||
named_mcp_clients: HashMap<String, Arc<dyn McpClient>>,
|
||||
storage: Arc<StorageManager>,
|
||||
vault: Option<Arc<Mutex<VaultHandle>>>,
|
||||
master_key: Option<Arc<Vec<u8>>>,
|
||||
@@ -103,6 +141,9 @@ pub struct SessionController {
|
||||
ui: Arc<dyn UiController>,
|
||||
enable_code_tools: bool,
|
||||
current_mode: Mode,
|
||||
missing_oauth_servers: Vec<String>,
|
||||
event_tx: Option<UnboundedSender<ControllerEvent>>,
|
||||
pending_tool_requests: HashMap<Uuid, PendingToolRequest>,
|
||||
}
|
||||
|
||||
async fn build_tools(
|
||||
@@ -211,12 +252,119 @@ async fn build_tools(
|
||||
}
|
||||
|
||||
impl SessionController {
|
||||
async fn create_mcp_clients(
|
||||
config: Arc<TokioMutex<Config>>,
|
||||
tool_registry: Arc<ToolRegistry>,
|
||||
schema_validator: Arc<SchemaValidator>,
|
||||
credential_manager: Option<Arc<CredentialManager>>,
|
||||
initial_mode: Mode,
|
||||
) -> Result<(
|
||||
Arc<dyn McpClient>,
|
||||
HashMap<String, Arc<dyn McpClient>>,
|
||||
Vec<String>,
|
||||
)> {
|
||||
let guard = config.lock().await;
|
||||
let config_arc = Arc::new(guard.clone());
|
||||
let factory = McpClientFactory::new(config_arc.clone(), tool_registry, schema_validator);
|
||||
|
||||
let mut missing_oauth_servers = Vec::new();
|
||||
let primary_runtime = if let Some(primary_cfg) = guard.effective_mcp_servers().first() {
|
||||
let (runtime, missing) =
|
||||
Self::runtime_secrets_for_server(credential_manager.clone(), primary_cfg).await?;
|
||||
if missing {
|
||||
missing_oauth_servers.push(primary_cfg.name.clone());
|
||||
}
|
||||
runtime
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let base_client = factory.create_with_secrets(primary_runtime)?;
|
||||
let primary: Arc<dyn McpClient> =
|
||||
Arc::new(PermissionLayer::new(base_client, config_arc.clone()));
|
||||
primary.set_mode(initial_mode).await?;
|
||||
|
||||
let mut clients: HashMap<String, Arc<dyn McpClient>> = HashMap::new();
|
||||
if let Some(primary_cfg) = guard.effective_mcp_servers().first() {
|
||||
clients.insert(primary_cfg.name.clone(), Arc::clone(&primary));
|
||||
}
|
||||
|
||||
for server_cfg in guard.effective_mcp_servers().iter().skip(1) {
|
||||
let (runtime, missing) =
|
||||
Self::runtime_secrets_for_server(credential_manager.clone(), server_cfg).await?;
|
||||
if missing {
|
||||
missing_oauth_servers.push(server_cfg.name.clone());
|
||||
}
|
||||
|
||||
match RemoteMcpClient::new_with_runtime(server_cfg, runtime) {
|
||||
Ok(remote) => {
|
||||
let client: Arc<dyn McpClient> =
|
||||
Arc::new(PermissionLayer::new(Box::new(remote), config_arc.clone()));
|
||||
if let Err(err) = client.set_mode(initial_mode).await {
|
||||
warn!(
|
||||
"Failed to initialize MCP server '{}' in mode {:?}: {}",
|
||||
server_cfg.name, initial_mode, err
|
||||
);
|
||||
}
|
||||
clients.insert(server_cfg.name.clone(), Arc::clone(&client));
|
||||
}
|
||||
Err(err) => warn!(
|
||||
"Failed to initialize MCP server '{}': {}",
|
||||
server_cfg.name, err
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
drop(guard);
|
||||
|
||||
Ok((primary, clients, missing_oauth_servers))
|
||||
}
|
||||
|
||||
async fn runtime_secrets_for_server(
|
||||
credential_manager: Option<Arc<CredentialManager>>,
|
||||
server: &McpServerConfig,
|
||||
) -> Result<(Option<McpRuntimeSecrets>, bool)> {
|
||||
if let Some(oauth) = &server.oauth {
|
||||
if let Some(manager) = credential_manager {
|
||||
match manager.load_oauth_token(&server.name).await? {
|
||||
Some(token) => {
|
||||
if token.access_token.trim().is_empty() || token.is_expired(Utc::now()) {
|
||||
return Ok((None, true));
|
||||
}
|
||||
let mut secrets = McpRuntimeSecrets::default();
|
||||
if let Some(env_name) = oauth.token_env.as_deref() {
|
||||
secrets
|
||||
.env_overrides
|
||||
.insert(env_name.to_string(), token.access_token.clone());
|
||||
}
|
||||
if matches!(
|
||||
server.transport.to_ascii_lowercase().as_str(),
|
||||
"http" | "websocket"
|
||||
) {
|
||||
let header_value =
|
||||
format!("{}{}", oauth.header_prefix(), token.access_token);
|
||||
secrets.http_header =
|
||||
Some((oauth.header_name().to_string(), header_value));
|
||||
}
|
||||
Ok((Some(secrets), false))
|
||||
}
|
||||
None => Ok((None, true)),
|
||||
}
|
||||
} else {
|
||||
Ok((None, true))
|
||||
}
|
||||
} else {
|
||||
Ok((None, false))
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn new(
|
||||
provider: Arc<dyn Provider>,
|
||||
config: Config,
|
||||
storage: Arc<StorageManager>,
|
||||
ui: Arc<dyn UiController>,
|
||||
enable_code_tools: bool,
|
||||
event_tx: Option<UnboundedSender<ControllerEvent>>,
|
||||
) -> Result<Self> {
|
||||
let config_arc = Arc::new(TokioMutex::new(config));
|
||||
// Acquire the config asynchronously to avoid blocking the runtime.
|
||||
@@ -292,19 +440,14 @@ impl SessionController {
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Create MCP client with permission layer
|
||||
let mcp_client: Arc<dyn McpClient> = {
|
||||
let guard = config_arc.lock().await;
|
||||
let factory = McpClientFactory::new(
|
||||
Arc::new(guard.clone()),
|
||||
tool_registry.clone(),
|
||||
schema_validator.clone(),
|
||||
);
|
||||
let base_client = factory.create()?;
|
||||
let client = Arc::new(PermissionLayer::new(base_client, Arc::new(guard.clone())));
|
||||
client.set_mode(initial_mode).await?;
|
||||
client
|
||||
};
|
||||
let (mcp_client, named_mcp_clients, missing_oauth_servers) = Self::create_mcp_clients(
|
||||
config_arc.clone(),
|
||||
tool_registry.clone(),
|
||||
schema_validator.clone(),
|
||||
credential_manager.clone(),
|
||||
initial_mode,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(Self {
|
||||
provider,
|
||||
@@ -317,6 +460,7 @@ impl SessionController {
|
||||
tool_registry,
|
||||
schema_validator,
|
||||
mcp_client,
|
||||
named_mcp_clients,
|
||||
storage,
|
||||
vault: vault_handle,
|
||||
master_key,
|
||||
@@ -324,6 +468,9 @@ impl SessionController {
|
||||
ui,
|
||||
enable_code_tools,
|
||||
current_mode: initial_mode,
|
||||
missing_oauth_servers,
|
||||
event_tx,
|
||||
pending_tool_requests: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -355,6 +502,63 @@ impl SessionController {
|
||||
self.formatter.set_role_label_mode(mode);
|
||||
}
|
||||
|
||||
/// Return the configured resource references aggregated across scopes.
|
||||
pub async fn configured_resources(&self) -> Vec<McpResourceConfig> {
|
||||
let guard = self.config.lock().await;
|
||||
guard.effective_mcp_resources().to_vec()
|
||||
}
|
||||
|
||||
/// Resolve a resource reference of the form `server:uri` (optionally prefixed with `@`).
|
||||
pub async fn resolve_resource_reference(&self, reference: &str) -> Result<Option<String>> {
|
||||
let (server, uri) = match Self::split_resource_reference(reference) {
|
||||
Some(parts) => parts,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
let resource_defined = {
|
||||
let guard = self.config.lock().await;
|
||||
guard.find_resource(&server, &uri).is_some()
|
||||
};
|
||||
|
||||
if !resource_defined {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let client = self
|
||||
.named_mcp_clients
|
||||
.get(&server)
|
||||
.cloned()
|
||||
.ok_or_else(|| {
|
||||
Error::Config(format!(
|
||||
"MCP server '{}' referenced by resource '{}' is not available",
|
||||
server, uri
|
||||
))
|
||||
})?;
|
||||
|
||||
let call = McpToolCall {
|
||||
name: "resources/get".to_string(),
|
||||
arguments: json!({ "uri": uri, "path": uri }),
|
||||
};
|
||||
let response = client.call_tool(call).await?;
|
||||
if let Some(text) = extract_resource_content(&response.output) {
|
||||
return Ok(Some(text));
|
||||
}
|
||||
|
||||
let formatted = serde_json::to_string_pretty(&response.output)
|
||||
.unwrap_or_else(|_| response.output.to_string());
|
||||
Ok(Some(formatted))
|
||||
}
|
||||
|
||||
fn split_resource_reference(reference: &str) -> Option<(String, String)> {
|
||||
let trimmed = reference.trim();
|
||||
let without_prefix = trimmed.strip_prefix('@').unwrap_or(trimmed);
|
||||
let (server, uri) = without_prefix.split_once(':')?;
|
||||
if server.is_empty() || uri.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some((server.to_string(), uri.to_string()))
|
||||
}
|
||||
|
||||
// Asynchronous access to the configuration (used internally).
|
||||
pub async fn config_async(&self) -> tokio::sync::MutexGuard<'_, Config> {
|
||||
self.config.lock().await
|
||||
@@ -378,6 +582,21 @@ impl SessionController {
|
||||
self.config.clone()
|
||||
}
|
||||
|
||||
pub async fn reload_mcp_clients(&mut self) -> Result<()> {
|
||||
let (primary, named, missing) = Self::create_mcp_clients(
|
||||
self.config.clone(),
|
||||
self.tool_registry.clone(),
|
||||
self.schema_validator.clone(),
|
||||
self.credential_manager.clone(),
|
||||
self.current_mode,
|
||||
)
|
||||
.await?;
|
||||
self.mcp_client = primary;
|
||||
self.named_mcp_clients = named;
|
||||
self.missing_oauth_servers = missing;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn grant_consent(&self, tool_name: &str, data_types: Vec<String>, endpoints: Vec<String>) {
|
||||
let mut consent = self
|
||||
.consent_manager
|
||||
@@ -385,9 +604,10 @@ impl SessionController {
|
||||
.expect("Consent manager mutex poisoned");
|
||||
consent.grant_consent(tool_name, data_types, endpoints);
|
||||
|
||||
if let Some(vault) = &self.vault
|
||||
&& let Err(e) = consent.persist_to_vault(vault)
|
||||
{
|
||||
let Some(vault) = &self.vault else {
|
||||
return;
|
||||
};
|
||||
if let Err(e) = consent.persist_to_vault(vault) {
|
||||
eprintln!("Warning: Failed to persist consent to vault: {}", e);
|
||||
}
|
||||
}
|
||||
@@ -407,10 +627,13 @@ impl SessionController {
|
||||
consent.grant_consent_with_scope(tool_name, data_types, endpoints, scope);
|
||||
|
||||
// Only persist to vault for permanent consent
|
||||
if is_permanent
|
||||
&& let Some(vault) = &self.vault
|
||||
&& let Err(e) = consent.persist_to_vault(vault)
|
||||
{
|
||||
if !is_permanent {
|
||||
return;
|
||||
}
|
||||
let Some(vault) = &self.vault else {
|
||||
return;
|
||||
};
|
||||
if let Err(e) = consent.persist_to_vault(vault) {
|
||||
eprintln!("Warning: Failed to persist consent to vault: {}", e);
|
||||
}
|
||||
}
|
||||
@@ -525,6 +748,115 @@ impl SessionController {
|
||||
self.schema_validator.clone()
|
||||
}
|
||||
|
||||
pub fn credential_manager(&self) -> Option<Arc<CredentialManager>> {
|
||||
self.credential_manager.clone()
|
||||
}
|
||||
|
||||
pub fn pending_oauth_servers(&self) -> Vec<String> {
|
||||
self.missing_oauth_servers.clone()
|
||||
}
|
||||
|
||||
pub async fn start_oauth_device_flow(&self, server: &str) -> Result<DeviceAuthorization> {
|
||||
let oauth_config = {
|
||||
let config = self.config.lock().await;
|
||||
let server_cfg = config
|
||||
.effective_mcp_servers()
|
||||
.iter()
|
||||
.find(|entry| entry.name == server)
|
||||
.ok_or_else(|| {
|
||||
Error::Config(format!("No MCP server named '{server}' is configured"))
|
||||
})?;
|
||||
server_cfg.oauth.clone().ok_or_else(|| {
|
||||
Error::Config(format!(
|
||||
"MCP server '{server}' does not define an OAuth configuration"
|
||||
))
|
||||
})?
|
||||
};
|
||||
|
||||
let client = OAuthClient::new(oauth_config)?;
|
||||
client.start_device_authorization().await
|
||||
}
|
||||
|
||||
pub async fn poll_oauth_device_flow(
|
||||
&mut self,
|
||||
server: &str,
|
||||
authorization: &DeviceAuthorization,
|
||||
) -> Result<DevicePollState> {
|
||||
let oauth_config = {
|
||||
let config = self.config.lock().await;
|
||||
let server_cfg = config
|
||||
.effective_mcp_servers()
|
||||
.iter()
|
||||
.find(|entry| entry.name == server)
|
||||
.ok_or_else(|| {
|
||||
Error::Config(format!("No MCP server named '{server}' is configured"))
|
||||
})?;
|
||||
server_cfg.oauth.clone().ok_or_else(|| {
|
||||
Error::Config(format!(
|
||||
"MCP server '{server}' does not define an OAuth configuration"
|
||||
))
|
||||
})?
|
||||
};
|
||||
|
||||
let client = OAuthClient::new(oauth_config)?;
|
||||
match client.poll_device_token(authorization).await? {
|
||||
DevicePollState::Pending { retry_in } => Ok(DevicePollState::Pending { retry_in }),
|
||||
DevicePollState::Complete(token) => {
|
||||
let manager = self.credential_manager.as_ref().cloned().ok_or_else(|| {
|
||||
Error::Config(
|
||||
"OAuth token storage requires encrypted local data; set \
|
||||
privacy.encrypt_local_data = true in the configuration."
|
||||
.to_string(),
|
||||
)
|
||||
})?;
|
||||
|
||||
manager.store_oauth_token(server, &token).await?;
|
||||
self.missing_oauth_servers.retain(|entry| entry != server);
|
||||
|
||||
Ok(DevicePollState::Complete(token))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn list_mcp_tools(&self) -> Vec<(String, crate::mcp::McpToolDescriptor)> {
|
||||
let mut entries = Vec::new();
|
||||
for (server, client) in self.named_mcp_clients.iter() {
|
||||
let server_name = server.clone();
|
||||
let client = Arc::clone(client);
|
||||
match client.list_tools().await {
|
||||
Ok(tools) => {
|
||||
for descriptor in tools {
|
||||
entries.push((server_name.clone(), descriptor));
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
warn!(
|
||||
"Failed to list tools for MCP server '{}': {}",
|
||||
server_name, err
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
entries
|
||||
}
|
||||
|
||||
pub async fn call_mcp_tool(
|
||||
&self,
|
||||
server: &str,
|
||||
tool: &str,
|
||||
arguments: Value,
|
||||
) -> Result<crate::mcp::McpToolResponse> {
|
||||
let client = self.named_mcp_clients.get(server).cloned().ok_or_else(|| {
|
||||
Error::Config(format!("No MCP server named '{}' is registered", server))
|
||||
})?;
|
||||
client
|
||||
.call_tool(McpToolCall {
|
||||
name: tool.to_string(),
|
||||
arguments,
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
pub fn mcp_server(&self) -> crate::mcp::McpServer {
|
||||
crate::mcp::McpServer::new(self.tool_registry(), self.schema_validator())
|
||||
}
|
||||
@@ -926,14 +1258,84 @@ impl SessionController {
|
||||
.append_stream_chunk(message_id, &chunk.message.content, chunk.is_final)
|
||||
}
|
||||
|
||||
pub fn check_streaming_tool_calls(&self, message_id: Uuid) -> Option<Vec<ToolCall>> {
|
||||
self.conversation
|
||||
pub fn check_streaming_tool_calls(&mut self, message_id: Uuid) -> Option<Vec<ToolCall>> {
|
||||
let maybe_calls = self
|
||||
.conversation
|
||||
.active()
|
||||
.messages
|
||||
.iter()
|
||||
.find(|m| m.id == message_id)
|
||||
.and_then(|m| m.tool_calls.clone())
|
||||
.filter(|calls| !calls.is_empty())
|
||||
.filter(|calls| !calls.is_empty());
|
||||
|
||||
let calls = maybe_calls?;
|
||||
|
||||
if !self
|
||||
.pending_tool_requests
|
||||
.values()
|
||||
.any(|pending| pending.message_id == message_id)
|
||||
{
|
||||
if let Some((tool_name, data_types, endpoints)) =
|
||||
self.check_tools_consent_needed(&calls).into_iter().next()
|
||||
{
|
||||
let request_id = Uuid::new_v4();
|
||||
let pending = PendingToolRequest {
|
||||
message_id,
|
||||
tool_name: tool_name.clone(),
|
||||
data_types: data_types.clone(),
|
||||
endpoints: endpoints.clone(),
|
||||
tool_calls: calls.clone(),
|
||||
};
|
||||
self.pending_tool_requests.insert(request_id, pending);
|
||||
|
||||
if let Some(tx) = &self.event_tx {
|
||||
let _ = tx.send(ControllerEvent::ToolRequested {
|
||||
request_id,
|
||||
message_id,
|
||||
tool_name,
|
||||
data_types,
|
||||
endpoints,
|
||||
tool_calls: calls.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(calls)
|
||||
}
|
||||
|
||||
pub fn resolve_tool_consent(
|
||||
&mut self,
|
||||
request_id: Uuid,
|
||||
scope: ConsentScope,
|
||||
) -> Result<ToolConsentResolution> {
|
||||
let pending = self
|
||||
.pending_tool_requests
|
||||
.remove(&request_id)
|
||||
.ok_or_else(|| {
|
||||
Error::InvalidInput(format!("Unknown tool consent request: {}", request_id))
|
||||
})?;
|
||||
|
||||
let PendingToolRequest {
|
||||
message_id,
|
||||
tool_name,
|
||||
data_types,
|
||||
endpoints,
|
||||
tool_calls,
|
||||
..
|
||||
} = pending;
|
||||
|
||||
if !matches!(scope, ConsentScope::Denied) {
|
||||
self.grant_consent_with_scope(&tool_name, data_types, endpoints, scope.clone());
|
||||
}
|
||||
|
||||
Ok(ToolConsentResolution {
|
||||
request_id,
|
||||
message_id,
|
||||
tool_name,
|
||||
scope,
|
||||
tool_calls,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn cancel_stream(&mut self, message_id: Uuid, notice: &str) -> Result<()> {
|
||||
@@ -985,3 +1387,195 @@ impl SessionController {
|
||||
Ok("Empty conversation".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::Provider;
|
||||
use crate::config::{Config, McpMode, McpOAuthConfig, McpServerConfig};
|
||||
use crate::llm::test_utils::MockProvider;
|
||||
use crate::storage::StorageManager;
|
||||
use crate::ui::NoOpUiController;
|
||||
use chrono::Utc;
|
||||
use httpmock::prelude::*;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tempfile::tempdir;
|
||||
|
||||
const SERVER_NAME: &str = "oauth-test";
|
||||
|
||||
fn build_oauth_config(server: &MockServer) -> McpOAuthConfig {
|
||||
McpOAuthConfig {
|
||||
client_id: "owlen-client".to_string(),
|
||||
client_secret: None,
|
||||
authorize_url: server.url("/authorize"),
|
||||
token_url: server.url("/token"),
|
||||
device_authorization_url: Some(server.url("/device")),
|
||||
redirect_url: None,
|
||||
scopes: vec!["repo".to_string()],
|
||||
token_env: Some("OAUTH_TOKEN".to_string()),
|
||||
header: Some("Authorization".to_string()),
|
||||
header_prefix: Some("Bearer ".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_config(server: &MockServer) -> Config {
|
||||
let mut config = Config::default();
|
||||
config.mcp.mode = McpMode::LocalOnly;
|
||||
let oauth = build_oauth_config(server);
|
||||
|
||||
let mut env = HashMap::new();
|
||||
env.insert("OWLEN_ENV".to_string(), "test".to_string());
|
||||
|
||||
config.mcp_servers = vec![McpServerConfig {
|
||||
name: SERVER_NAME.to_string(),
|
||||
command: server.url("/mcp"),
|
||||
args: Vec::new(),
|
||||
transport: "http".to_string(),
|
||||
env,
|
||||
oauth: Some(oauth),
|
||||
}];
|
||||
|
||||
config.refresh_mcp_servers(None).unwrap();
|
||||
config
|
||||
}
|
||||
|
||||
async fn build_session(server: &MockServer) -> (SessionController, tempfile::TempDir) {
|
||||
unsafe {
|
||||
std::env::set_var("OWLEN_MASTER_PASSWORD", "test-password");
|
||||
}
|
||||
|
||||
let temp_dir = tempdir().expect("tempdir");
|
||||
let storage_path = temp_dir.path().join("owlen.db");
|
||||
let storage = Arc::new(
|
||||
StorageManager::with_database_path(storage_path)
|
||||
.await
|
||||
.expect("storage"),
|
||||
);
|
||||
|
||||
let config = build_config(server);
|
||||
let provider: Arc<dyn Provider> = Arc::new(MockProvider::default()) as Arc<dyn Provider>;
|
||||
let ui = Arc::new(NoOpUiController);
|
||||
|
||||
let session = SessionController::new(provider, config, storage, ui, false, None)
|
||||
.await
|
||||
.expect("session");
|
||||
|
||||
(session, temp_dir)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_oauth_device_flow_returns_details() {
|
||||
let server = MockServer::start_async().await;
|
||||
let device = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST).path("/device");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"device_code": "device-abc",
|
||||
"user_code": "ABCD-1234",
|
||||
"verification_uri": "https://example.test/activate",
|
||||
"verification_uri_complete": "https://example.test/activate?user_code=ABCD-1234",
|
||||
"expires_in": 600,
|
||||
"interval": 5,
|
||||
"message": "Enter the code to continue."
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let (session, _dir) = build_session(&server).await;
|
||||
let authorization = session
|
||||
.start_oauth_device_flow(SERVER_NAME)
|
||||
.await
|
||||
.expect("device flow");
|
||||
|
||||
assert_eq!(authorization.user_code, "ABCD-1234");
|
||||
assert_eq!(
|
||||
authorization.verification_uri_complete.as_deref(),
|
||||
Some("https://example.test/activate?user_code=ABCD-1234")
|
||||
);
|
||||
assert!(authorization.expires_at > Utc::now());
|
||||
device.assert_async().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn poll_oauth_device_flow_stores_token_and_updates_state() {
|
||||
let server = MockServer::start_async().await;
|
||||
|
||||
let device = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST).path("/device");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"device_code": "device-xyz",
|
||||
"user_code": "WXYZ-9999",
|
||||
"verification_uri": "https://example.test/activate",
|
||||
"verification_uri_complete": "https://example.test/activate?user_code=WXYZ-9999",
|
||||
"expires_in": 600,
|
||||
"interval": 5
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let token = server
|
||||
.mock_async(|when, then| {
|
||||
when.method(POST)
|
||||
.path("/token")
|
||||
.body_contains("device_code=device-xyz");
|
||||
then.status(200)
|
||||
.header("content-type", "application/json")
|
||||
.json_body(json!({
|
||||
"access_token": "new-access-token",
|
||||
"refresh_token": "refresh-token",
|
||||
"expires_in": 3600,
|
||||
"token_type": "Bearer"
|
||||
}));
|
||||
})
|
||||
.await;
|
||||
|
||||
let (mut session, _dir) = build_session(&server).await;
|
||||
assert_eq!(session.pending_oauth_servers(), vec![SERVER_NAME]);
|
||||
|
||||
let authorization = session
|
||||
.start_oauth_device_flow(SERVER_NAME)
|
||||
.await
|
||||
.expect("device flow");
|
||||
|
||||
match session
|
||||
.poll_oauth_device_flow(SERVER_NAME, &authorization)
|
||||
.await
|
||||
.expect("token poll")
|
||||
{
|
||||
DevicePollState::Complete(token_info) => {
|
||||
assert_eq!(token_info.access_token, "new-access-token");
|
||||
assert_eq!(token_info.refresh_token.as_deref(), Some("refresh-token"));
|
||||
}
|
||||
other => panic!("expected token completion, got {other:?}"),
|
||||
}
|
||||
|
||||
assert!(
|
||||
session
|
||||
.pending_oauth_servers()
|
||||
.iter()
|
||||
.all(|entry| entry != SERVER_NAME),
|
||||
"server should be removed from pending list"
|
||||
);
|
||||
|
||||
let stored = session
|
||||
.credential_manager()
|
||||
.expect("credential manager")
|
||||
.load_oauth_token(SERVER_NAME)
|
||||
.await
|
||||
.expect("load token")
|
||||
.expect("token present");
|
||||
|
||||
assert_eq!(stored.access_token, "new-access-token");
|
||||
assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));
|
||||
|
||||
device.assert_async().await;
|
||||
token.assert_async().await;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,14 +3,14 @@
|
||||
use std::fmt;
|
||||
|
||||
/// High-level application state reported by the UI loop.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum AppState {
|
||||
Running,
|
||||
Quit,
|
||||
}
|
||||
|
||||
/// Vim-style input modes supported by the TUI.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum InputMode {
|
||||
Normal,
|
||||
Editing,
|
||||
@@ -21,6 +21,8 @@ pub enum InputMode {
|
||||
Command,
|
||||
SessionBrowser,
|
||||
ThemeBrowser,
|
||||
RepoSearch,
|
||||
SymbolSearch,
|
||||
}
|
||||
|
||||
impl fmt::Display for InputMode {
|
||||
@@ -35,14 +37,17 @@ impl fmt::Display for InputMode {
|
||||
InputMode::Command => "Command",
|
||||
InputMode::SessionBrowser => "Sessions",
|
||||
InputMode::ThemeBrowser => "Themes",
|
||||
InputMode::RepoSearch => "Search",
|
||||
InputMode::SymbolSearch => "Symbols",
|
||||
};
|
||||
f.write_str(label)
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents which panel is currently focused in the TUI layout.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum FocusedPanel {
|
||||
Files,
|
||||
Chat,
|
||||
Thinking,
|
||||
Input,
|
||||
|
||||
@@ -50,14 +50,14 @@ impl StorageManager {
|
||||
|
||||
/// Create a storage manager using the provided database path
|
||||
pub async fn with_database_path(database_path: PathBuf) -> Result<Self> {
|
||||
if let Some(parent) = database_path.parent()
|
||||
&& !parent.exists()
|
||||
{
|
||||
std::fs::create_dir_all(parent).map_err(|e| {
|
||||
Error::Storage(format!(
|
||||
"Failed to create database directory {parent:?}: {e}"
|
||||
))
|
||||
})?;
|
||||
if let Some(parent) = database_path.parent() {
|
||||
if !parent.exists() {
|
||||
std::fs::create_dir_all(parent).map_err(|e| {
|
||||
Error::Storage(format!(
|
||||
"Failed to create database directory {parent:?}: {e}"
|
||||
))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
let options = SqliteConnectOptions::from_str(&format!(
|
||||
@@ -431,13 +431,13 @@ impl StorageManager {
|
||||
}
|
||||
}
|
||||
|
||||
if migrated > 0
|
||||
&& let Err(err) = archive_legacy_directory(&legacy_dir)
|
||||
{
|
||||
println!(
|
||||
"Warning: migrated sessions but failed to archive legacy directory: {}",
|
||||
err
|
||||
);
|
||||
if migrated > 0 {
|
||||
if let Err(err) = archive_legacy_directory(&legacy_dir) {
|
||||
println!(
|
||||
"Warning: migrated sessions but failed to archive legacy directory: {}",
|
||||
err
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
println!("Migrated {} legacy sessions.", migrated);
|
||||
|
||||
@@ -36,6 +36,42 @@ pub struct Theme {
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub unfocused_panel_border: Color,
|
||||
|
||||
/// Foreground color for the active pane beacon (`▌`)
|
||||
#[serde(default = "Theme::default_focus_beacon_fg")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub focus_beacon_fg: Color,
|
||||
|
||||
/// Background color for the active pane beacon (`▌`)
|
||||
#[serde(default = "Theme::default_focus_beacon_bg")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub focus_beacon_bg: Color,
|
||||
|
||||
/// Foreground color for the inactive pane beacon (`▌`)
|
||||
#[serde(default = "Theme::default_unfocused_beacon_fg")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub unfocused_beacon_fg: Color,
|
||||
|
||||
/// Title color for active pane headers
|
||||
#[serde(default = "Theme::default_pane_header_active")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub pane_header_active: Color,
|
||||
|
||||
/// Title color for inactive pane headers
|
||||
#[serde(default = "Theme::default_pane_header_inactive")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub pane_header_inactive: Color,
|
||||
|
||||
/// Hint text color used within pane headers
|
||||
#[serde(default = "Theme::default_pane_hint_text")]
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
pub pane_hint_text: Color,
|
||||
|
||||
/// Color for user message role indicator
|
||||
#[serde(deserialize_with = "deserialize_color")]
|
||||
#[serde(serialize_with = "serialize_color")]
|
||||
@@ -313,6 +349,30 @@ impl Theme {
|
||||
Color::Cyan
|
||||
}
|
||||
|
||||
const fn default_focus_beacon_fg() -> Color {
|
||||
Color::LightMagenta
|
||||
}
|
||||
|
||||
const fn default_focus_beacon_bg() -> Color {
|
||||
Color::Black
|
||||
}
|
||||
|
||||
const fn default_unfocused_beacon_fg() -> Color {
|
||||
Color::DarkGray
|
||||
}
|
||||
|
||||
const fn default_pane_header_active() -> Color {
|
||||
Color::White
|
||||
}
|
||||
|
||||
const fn default_pane_header_inactive() -> Color {
|
||||
Color::Gray
|
||||
}
|
||||
|
||||
const fn default_pane_hint_text() -> Color {
|
||||
Color::DarkGray
|
||||
}
|
||||
|
||||
const fn default_operating_chat_fg() -> Color {
|
||||
Color::Black
|
||||
}
|
||||
@@ -472,46 +532,52 @@ fn default_dark() -> Theme {
|
||||
name: "default_dark".to_string(),
|
||||
text: Color::White,
|
||||
background: Color::Black,
|
||||
focused_panel_border: Color::LightMagenta,
|
||||
unfocused_panel_border: Color::Rgb(95, 20, 135),
|
||||
focused_panel_border: Color::Rgb(216, 160, 255),
|
||||
unfocused_panel_border: Color::Rgb(137, 82, 204),
|
||||
focus_beacon_fg: Color::Rgb(248, 229, 255),
|
||||
focus_beacon_bg: Color::Rgb(38, 10, 58),
|
||||
unfocused_beacon_fg: Color::Rgb(130, 130, 130),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Color::Rgb(210, 210, 210),
|
||||
pane_hint_text: Color::Rgb(210, 210, 210),
|
||||
user_message_role: Color::LightBlue,
|
||||
assistant_message_role: Color::Yellow,
|
||||
tool_output: Color::Gray,
|
||||
thinking_panel_title: Color::LightMagenta,
|
||||
command_bar_background: Color::Black,
|
||||
status_background: Color::Black,
|
||||
mode_normal: Color::LightBlue,
|
||||
mode_editing: Color::LightGreen,
|
||||
mode_model_selection: Color::LightYellow,
|
||||
mode_provider_selection: Color::LightCyan,
|
||||
mode_help: Color::LightMagenta,
|
||||
mode_visual: Color::Magenta,
|
||||
mode_command: Color::Yellow,
|
||||
selection_bg: Color::LightBlue,
|
||||
tool_output: Color::Rgb(200, 200, 200),
|
||||
thinking_panel_title: Color::Rgb(234, 182, 255),
|
||||
command_bar_background: Color::Rgb(10, 10, 10),
|
||||
status_background: Color::Rgb(12, 12, 12),
|
||||
mode_normal: Color::Rgb(117, 200, 255),
|
||||
mode_editing: Color::Rgb(144, 242, 170),
|
||||
mode_model_selection: Color::Rgb(255, 226, 140),
|
||||
mode_provider_selection: Color::Rgb(164, 235, 255),
|
||||
mode_help: Color::Rgb(234, 182, 255),
|
||||
mode_visual: Color::Rgb(255, 170, 255),
|
||||
mode_command: Color::Rgb(255, 220, 120),
|
||||
selection_bg: Color::Rgb(56, 140, 240),
|
||||
selection_fg: Color::Black,
|
||||
cursor: Color::Magenta,
|
||||
cursor: Color::Rgb(255, 196, 255),
|
||||
code_block_background: Color::Rgb(25, 25, 25),
|
||||
code_block_border: Color::LightMagenta,
|
||||
code_block_border: Color::Rgb(216, 160, 255),
|
||||
code_block_text: Color::White,
|
||||
code_block_keyword: Color::Yellow,
|
||||
code_block_string: Color::LightGreen,
|
||||
code_block_comment: Color::Gray,
|
||||
placeholder: Color::DarkGray,
|
||||
code_block_keyword: Color::Rgb(255, 220, 120),
|
||||
code_block_string: Color::Rgb(144, 242, 170),
|
||||
code_block_comment: Color::Rgb(170, 170, 170),
|
||||
placeholder: Color::Rgb(180, 180, 180),
|
||||
error: Color::Red,
|
||||
info: Color::LightGreen,
|
||||
agent_thought: Color::LightBlue,
|
||||
agent_action: Color::Yellow,
|
||||
agent_action_input: Color::LightCyan,
|
||||
agent_observation: Color::LightGreen,
|
||||
agent_final_answer: Color::Magenta,
|
||||
info: Color::Rgb(144, 242, 170),
|
||||
agent_thought: Color::Rgb(117, 200, 255),
|
||||
agent_action: Color::Rgb(255, 220, 120),
|
||||
agent_action_input: Color::Rgb(164, 235, 255),
|
||||
agent_observation: Color::Rgb(144, 242, 170),
|
||||
agent_final_answer: Color::Rgb(255, 170, 255),
|
||||
agent_badge_running_fg: Color::Black,
|
||||
agent_badge_running_bg: Color::Yellow,
|
||||
agent_badge_idle_fg: Color::Black,
|
||||
agent_badge_idle_bg: Color::Cyan,
|
||||
operating_chat_fg: Color::Black,
|
||||
operating_chat_bg: Color::Blue,
|
||||
operating_chat_bg: Color::Rgb(117, 200, 255),
|
||||
operating_code_fg: Color::Black,
|
||||
operating_code_bg: Color::Magenta,
|
||||
operating_code_bg: Color::Rgb(255, 170, 255),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -523,6 +589,12 @@ fn default_light() -> Theme {
|
||||
background: Color::White,
|
||||
focused_panel_border: Color::Rgb(74, 144, 226),
|
||||
unfocused_panel_border: Color::Rgb(221, 221, 221),
|
||||
focus_beacon_fg: Theme::default_focus_beacon_fg(),
|
||||
focus_beacon_bg: Theme::default_focus_beacon_bg(),
|
||||
unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Theme::default_pane_header_inactive(),
|
||||
pane_hint_text: Theme::default_pane_hint_text(),
|
||||
user_message_role: Color::Rgb(0, 85, 164),
|
||||
assistant_message_role: Color::Rgb(142, 68, 173),
|
||||
tool_output: Color::Gray,
|
||||
@@ -572,7 +644,13 @@ fn gruvbox() -> Theme {
|
||||
background: Color::Rgb(40, 40, 40), // #282828
|
||||
focused_panel_border: Color::Rgb(254, 128, 25), // #fe8019 (orange)
|
||||
unfocused_panel_border: Color::Rgb(124, 111, 100), // #7c6f64
|
||||
user_message_role: Color::Rgb(184, 187, 38), // #b8bb26 (green)
|
||||
focus_beacon_fg: Theme::default_focus_beacon_fg(),
|
||||
focus_beacon_bg: Theme::default_focus_beacon_bg(),
|
||||
unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Theme::default_pane_header_inactive(),
|
||||
pane_hint_text: Theme::default_pane_hint_text(),
|
||||
user_message_role: Color::Rgb(184, 187, 38), // #b8bb26 (green)
|
||||
assistant_message_role: Color::Rgb(131, 165, 152), // #83a598 (blue)
|
||||
tool_output: Color::Rgb(146, 131, 116),
|
||||
thinking_panel_title: Color::Rgb(211, 134, 155), // #d3869b (purple)
|
||||
@@ -617,11 +695,17 @@ fn gruvbox() -> Theme {
|
||||
fn dracula() -> Theme {
|
||||
Theme {
|
||||
name: "dracula".to_string(),
|
||||
text: Color::Rgb(248, 248, 242), // #f8f8f2
|
||||
background: Color::Rgb(40, 42, 54), // #282a36
|
||||
focused_panel_border: Color::Rgb(255, 121, 198), // #ff79c6 (pink)
|
||||
unfocused_panel_border: Color::Rgb(68, 71, 90), // #44475a
|
||||
user_message_role: Color::Rgb(139, 233, 253), // #8be9fd (cyan)
|
||||
text: Color::Rgb(248, 248, 242), // #f8f8f2
|
||||
background: Color::Rgb(40, 42, 54), // #282a36
|
||||
focused_panel_border: Color::Rgb(255, 121, 198), // #ff79c6 (pink)
|
||||
unfocused_panel_border: Color::Rgb(68, 71, 90), // #44475a
|
||||
focus_beacon_fg: Theme::default_focus_beacon_fg(),
|
||||
focus_beacon_bg: Theme::default_focus_beacon_bg(),
|
||||
unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Theme::default_pane_header_inactive(),
|
||||
pane_hint_text: Theme::default_pane_hint_text(),
|
||||
user_message_role: Color::Rgb(139, 233, 253), // #8be9fd (cyan)
|
||||
assistant_message_role: Color::Rgb(255, 121, 198), // #ff79c6 (pink)
|
||||
tool_output: Color::Rgb(98, 114, 164),
|
||||
thinking_panel_title: Color::Rgb(189, 147, 249), // #bd93f9 (purple)
|
||||
@@ -670,6 +754,12 @@ fn solarized() -> Theme {
|
||||
background: Color::Rgb(0, 43, 54), // #002b36 (base03)
|
||||
focused_panel_border: Color::Rgb(38, 139, 210), // #268bd2 (blue)
|
||||
unfocused_panel_border: Color::Rgb(7, 54, 66), // #073642 (base02)
|
||||
focus_beacon_fg: Theme::default_focus_beacon_fg(),
|
||||
focus_beacon_bg: Theme::default_focus_beacon_bg(),
|
||||
unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Theme::default_pane_header_inactive(),
|
||||
pane_hint_text: Theme::default_pane_hint_text(),
|
||||
user_message_role: Color::Rgb(42, 161, 152), // #2aa198 (cyan)
|
||||
assistant_message_role: Color::Rgb(203, 75, 22), // #cb4b16 (orange)
|
||||
tool_output: Color::Rgb(101, 123, 131),
|
||||
@@ -719,6 +809,12 @@ fn midnight_ocean() -> Theme {
|
||||
background: Color::Rgb(13, 17, 23),
|
||||
focused_panel_border: Color::Rgb(88, 166, 255),
|
||||
unfocused_panel_border: Color::Rgb(48, 54, 61),
|
||||
focus_beacon_fg: Theme::default_focus_beacon_fg(),
|
||||
focus_beacon_bg: Theme::default_focus_beacon_bg(),
|
||||
unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Theme::default_pane_header_inactive(),
|
||||
pane_hint_text: Theme::default_pane_hint_text(),
|
||||
user_message_role: Color::Rgb(121, 192, 255),
|
||||
assistant_message_role: Color::Rgb(137, 221, 255),
|
||||
tool_output: Color::Rgb(84, 110, 122),
|
||||
@@ -764,11 +860,17 @@ fn midnight_ocean() -> Theme {
|
||||
fn rose_pine() -> Theme {
|
||||
Theme {
|
||||
name: "rose-pine".to_string(),
|
||||
text: Color::Rgb(224, 222, 244), // #e0def4
|
||||
background: Color::Rgb(25, 23, 36), // #191724
|
||||
focused_panel_border: Color::Rgb(235, 111, 146), // #eb6f92 (love)
|
||||
unfocused_panel_border: Color::Rgb(38, 35, 58), // #26233a
|
||||
user_message_role: Color::Rgb(49, 116, 143), // #31748f (foam)
|
||||
text: Color::Rgb(224, 222, 244), // #e0def4
|
||||
background: Color::Rgb(25, 23, 36), // #191724
|
||||
focused_panel_border: Color::Rgb(235, 111, 146), // #eb6f92 (love)
|
||||
unfocused_panel_border: Color::Rgb(38, 35, 58), // #26233a
|
||||
focus_beacon_fg: Theme::default_focus_beacon_fg(),
|
||||
focus_beacon_bg: Theme::default_focus_beacon_bg(),
|
||||
unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Theme::default_pane_header_inactive(),
|
||||
pane_hint_text: Theme::default_pane_hint_text(),
|
||||
user_message_role: Color::Rgb(49, 116, 143), // #31748f (foam)
|
||||
assistant_message_role: Color::Rgb(156, 207, 216), // #9ccfd8 (foam light)
|
||||
tool_output: Color::Rgb(110, 106, 134),
|
||||
thinking_panel_title: Color::Rgb(196, 167, 231), // #c4a7e7 (iris)
|
||||
@@ -813,11 +915,17 @@ fn rose_pine() -> Theme {
|
||||
fn monokai() -> Theme {
|
||||
Theme {
|
||||
name: "monokai".to_string(),
|
||||
text: Color::Rgb(248, 248, 242), // #f8f8f2
|
||||
background: Color::Rgb(39, 40, 34), // #272822
|
||||
focused_panel_border: Color::Rgb(249, 38, 114), // #f92672 (pink)
|
||||
unfocused_panel_border: Color::Rgb(117, 113, 94), // #75715e
|
||||
user_message_role: Color::Rgb(102, 217, 239), // #66d9ef (cyan)
|
||||
text: Color::Rgb(248, 248, 242), // #f8f8f2
|
||||
background: Color::Rgb(39, 40, 34), // #272822
|
||||
focused_panel_border: Color::Rgb(249, 38, 114), // #f92672 (pink)
|
||||
unfocused_panel_border: Color::Rgb(117, 113, 94), // #75715e
|
||||
focus_beacon_fg: Theme::default_focus_beacon_fg(),
|
||||
focus_beacon_bg: Theme::default_focus_beacon_bg(),
|
||||
unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Theme::default_pane_header_inactive(),
|
||||
pane_hint_text: Theme::default_pane_hint_text(),
|
||||
user_message_role: Color::Rgb(102, 217, 239), // #66d9ef (cyan)
|
||||
assistant_message_role: Color::Rgb(174, 129, 255), // #ae81ff (purple)
|
||||
tool_output: Color::Rgb(117, 113, 94),
|
||||
thinking_panel_title: Color::Rgb(230, 219, 116), // #e6db74 (yellow)
|
||||
@@ -862,11 +970,17 @@ fn monokai() -> Theme {
|
||||
fn material_dark() -> Theme {
|
||||
Theme {
|
||||
name: "material-dark".to_string(),
|
||||
text: Color::Rgb(238, 255, 255), // #eeffff
|
||||
background: Color::Rgb(38, 50, 56), // #263238
|
||||
focused_panel_border: Color::Rgb(128, 203, 196), // #80cbc4 (cyan)
|
||||
unfocused_panel_border: Color::Rgb(84, 110, 122), // #546e7a
|
||||
user_message_role: Color::Rgb(130, 170, 255), // #82aaff (blue)
|
||||
text: Color::Rgb(238, 255, 255), // #eeffff
|
||||
background: Color::Rgb(38, 50, 56), // #263238
|
||||
focused_panel_border: Color::Rgb(128, 203, 196), // #80cbc4 (cyan)
|
||||
unfocused_panel_border: Color::Rgb(84, 110, 122), // #546e7a
|
||||
focus_beacon_fg: Theme::default_focus_beacon_fg(),
|
||||
focus_beacon_bg: Theme::default_focus_beacon_bg(),
|
||||
unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Theme::default_pane_header_inactive(),
|
||||
pane_hint_text: Theme::default_pane_hint_text(),
|
||||
user_message_role: Color::Rgb(130, 170, 255), // #82aaff (blue)
|
||||
assistant_message_role: Color::Rgb(199, 146, 234), // #c792ea (purple)
|
||||
tool_output: Color::Rgb(84, 110, 122),
|
||||
thinking_panel_title: Color::Rgb(255, 203, 107), // #ffcb6b (yellow)
|
||||
@@ -915,6 +1029,12 @@ fn material_light() -> Theme {
|
||||
background: Color::Rgb(236, 239, 241),
|
||||
focused_panel_border: Color::Rgb(0, 150, 136),
|
||||
unfocused_panel_border: Color::Rgb(176, 190, 197),
|
||||
focus_beacon_fg: Theme::default_focus_beacon_fg(),
|
||||
focus_beacon_bg: Theme::default_focus_beacon_bg(),
|
||||
unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Theme::default_pane_header_inactive(),
|
||||
pane_hint_text: Theme::default_pane_hint_text(),
|
||||
user_message_role: Color::Rgb(68, 138, 255),
|
||||
assistant_message_role: Color::Rgb(124, 77, 255),
|
||||
tool_output: Color::Rgb(144, 164, 174),
|
||||
@@ -964,6 +1084,12 @@ fn grayscale_high_contrast() -> Theme {
|
||||
background: Color::Black,
|
||||
focused_panel_border: Color::White,
|
||||
unfocused_panel_border: Color::Rgb(76, 76, 76),
|
||||
focus_beacon_fg: Theme::default_focus_beacon_fg(),
|
||||
focus_beacon_bg: Theme::default_focus_beacon_bg(),
|
||||
unfocused_beacon_fg: Theme::default_unfocused_beacon_fg(),
|
||||
pane_header_active: Theme::default_pane_header_active(),
|
||||
pane_header_inactive: Theme::default_pane_header_inactive(),
|
||||
pane_hint_text: Theme::default_pane_hint_text(),
|
||||
user_message_role: Color::Rgb(240, 240, 240),
|
||||
assistant_message_role: Color::Rgb(214, 214, 214),
|
||||
tool_output: Color::Rgb(189, 189, 189),
|
||||
|
||||
310
crates/owlen-core/tests/agent_tool_flow.rs
Normal file
310
crates/owlen-core/tests/agent_tool_flow.rs
Normal file
@@ -0,0 +1,310 @@
|
||||
use std::{any::Any, collections::HashMap, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use owlen_core::{
|
||||
Config, Error, Mode, Provider,
|
||||
config::McpMode,
|
||||
consent::ConsentScope,
|
||||
mcp::{
|
||||
McpClient, McpToolCall, McpToolDescriptor, McpToolResponse,
|
||||
failover::{FailoverMcpClient, ServerEntry},
|
||||
},
|
||||
session::{ControllerEvent, SessionController, SessionOutcome},
|
||||
storage::StorageManager,
|
||||
types::{ChatParameters, ChatRequest, ChatResponse, Message, ModelInfo, Role, ToolCall},
|
||||
ui::NoOpUiController,
|
||||
};
|
||||
use tempfile::tempdir;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
struct StreamingToolProvider;
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for StreamingToolProvider {
|
||||
fn name(&self) -> &str {
|
||||
"mock-streaming-provider"
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> owlen_core::Result<Vec<ModelInfo>> {
|
||||
Ok(vec![ModelInfo {
|
||||
id: "mock-model".into(),
|
||||
name: "Mock Model".into(),
|
||||
description: Some("A mock model that emits tool calls".into()),
|
||||
provider: self.name().into(),
|
||||
context_window: Some(4096),
|
||||
capabilities: vec!["chat".into(), "tools".into()],
|
||||
supports_tools: true,
|
||||
}])
|
||||
}
|
||||
|
||||
async fn send_prompt(&self, _request: ChatRequest) -> owlen_core::Result<ChatResponse> {
|
||||
let mut message = Message::assistant("tool-call".to_string());
|
||||
message.tool_calls = Some(vec![ToolCall {
|
||||
id: "call-1".to_string(),
|
||||
name: "resources/write".to_string(),
|
||||
arguments: serde_json::json!({"path": "README.md", "content": "hello"}),
|
||||
}]);
|
||||
|
||||
Ok(ChatResponse {
|
||||
message,
|
||||
usage: None,
|
||||
is_streaming: false,
|
||||
is_final: true,
|
||||
})
|
||||
}
|
||||
|
||||
async fn stream_prompt(
|
||||
&self,
|
||||
_request: ChatRequest,
|
||||
) -> owlen_core::Result<owlen_core::ChatStream> {
|
||||
let mut first_chunk = Message::assistant(
|
||||
"Thought: need to update README.\nAction: resources/write".to_string(),
|
||||
);
|
||||
first_chunk.tool_calls = Some(vec![ToolCall {
|
||||
id: "call-1".to_string(),
|
||||
name: "resources/write".to_string(),
|
||||
arguments: serde_json::json!({"path": "README.md", "content": "hello"}),
|
||||
}]);
|
||||
|
||||
let chunk = ChatResponse {
|
||||
message: first_chunk,
|
||||
usage: None,
|
||||
is_streaming: true,
|
||||
is_final: false,
|
||||
};
|
||||
|
||||
Ok(Box::pin(futures::stream::iter(vec![Ok(chunk)])))
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> owlen_core::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &(dyn Any + Send + Sync) {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
fn tool_descriptor() -> McpToolDescriptor {
|
||||
McpToolDescriptor {
|
||||
name: "web_search".to_string(),
|
||||
description: "search".to_string(),
|
||||
input_schema: serde_json::json!({"type": "object"}),
|
||||
requires_network: true,
|
||||
requires_filesystem: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
struct TimeoutClient;
|
||||
|
||||
#[async_trait]
|
||||
impl McpClient for TimeoutClient {
|
||||
async fn list_tools(&self) -> owlen_core::Result<Vec<McpToolDescriptor>> {
|
||||
Ok(vec![tool_descriptor()])
|
||||
}
|
||||
|
||||
async fn call_tool(&self, _call: McpToolCall) -> owlen_core::Result<McpToolResponse> {
|
||||
Err(Error::Network(
|
||||
"timeout while contacting remote web search endpoint".into(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
struct CachedResponseClient {
|
||||
response: Arc<McpToolResponse>,
|
||||
}
|
||||
|
||||
impl CachedResponseClient {
|
||||
fn new() -> Self {
|
||||
let mut metadata = HashMap::new();
|
||||
metadata.insert("source".to_string(), "cache".to_string());
|
||||
metadata.insert("cached".to_string(), "true".to_string());
|
||||
|
||||
let response = McpToolResponse {
|
||||
name: "web_search".to_string(),
|
||||
success: true,
|
||||
output: serde_json::json!({
|
||||
"query": "rust",
|
||||
"results": [
|
||||
{"title": "Rust Programming Language", "url": "https://www.rust-lang.org"}
|
||||
],
|
||||
"note": "cached result"
|
||||
}),
|
||||
metadata,
|
||||
duration_ms: 0,
|
||||
};
|
||||
|
||||
Self {
|
||||
response: Arc::new(response),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl McpClient for CachedResponseClient {
|
||||
async fn list_tools(&self) -> owlen_core::Result<Vec<McpToolDescriptor>> {
|
||||
Ok(vec![tool_descriptor()])
|
||||
}
|
||||
|
||||
async fn call_tool(&self, _call: McpToolCall) -> owlen_core::Result<McpToolResponse> {
|
||||
Ok((*self.response).clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn streaming_file_write_consent_denied_returns_resolution() {
|
||||
let temp_dir = tempdir().expect("temp dir");
|
||||
let storage = StorageManager::with_database_path(temp_dir.path().join("owlen-tests.db"))
|
||||
.await
|
||||
.expect("storage");
|
||||
|
||||
let mut config = Config::default();
|
||||
config.general.enable_streaming = true;
|
||||
config.privacy.encrypt_local_data = false;
|
||||
config.privacy.require_consent_per_session = true;
|
||||
config.general.default_model = Some("mock-model".into());
|
||||
config.mcp.mode = McpMode::LocalOnly;
|
||||
config
|
||||
.refresh_mcp_servers(None)
|
||||
.expect("refresh MCP servers");
|
||||
|
||||
let provider: Arc<dyn Provider> = Arc::new(StreamingToolProvider);
|
||||
let ui = Arc::new(NoOpUiController);
|
||||
let (event_tx, mut event_rx) = mpsc::unbounded_channel::<ControllerEvent>();
|
||||
|
||||
let mut session = SessionController::new(
|
||||
provider,
|
||||
config,
|
||||
Arc::new(storage),
|
||||
ui,
|
||||
true,
|
||||
Some(event_tx),
|
||||
)
|
||||
.await
|
||||
.expect("session controller");
|
||||
|
||||
session
|
||||
.set_operating_mode(Mode::Code)
|
||||
.await
|
||||
.expect("code mode");
|
||||
|
||||
let outcome = session
|
||||
.send_message(
|
||||
"Please write to README".to_string(),
|
||||
ChatParameters {
|
||||
stream: true,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await
|
||||
.expect("send message");
|
||||
|
||||
let (response_id, mut stream) = if let SessionOutcome::Streaming {
|
||||
response_id,
|
||||
stream,
|
||||
} = outcome
|
||||
{
|
||||
(response_id, stream)
|
||||
} else {
|
||||
panic!("expected streaming outcome");
|
||||
};
|
||||
|
||||
session
|
||||
.mark_stream_placeholder(response_id, "▌")
|
||||
.expect("placeholder");
|
||||
|
||||
let chunk = stream
|
||||
.next()
|
||||
.await
|
||||
.expect("stream chunk")
|
||||
.expect("chunk result");
|
||||
session
|
||||
.apply_stream_chunk(response_id, &chunk)
|
||||
.expect("apply chunk");
|
||||
|
||||
let tool_calls = session
|
||||
.check_streaming_tool_calls(response_id)
|
||||
.expect("tool calls");
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].name, "resources/write");
|
||||
|
||||
let event = event_rx.recv().await.expect("controller event");
|
||||
let request_id = match event {
|
||||
ControllerEvent::ToolRequested {
|
||||
request_id,
|
||||
tool_name,
|
||||
data_types,
|
||||
endpoints,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(tool_name, "resources/write");
|
||||
assert!(data_types.iter().any(|t| t.contains("file")));
|
||||
assert!(endpoints.iter().any(|e| e.contains("filesystem")));
|
||||
request_id
|
||||
}
|
||||
};
|
||||
|
||||
let resolution = session
|
||||
.resolve_tool_consent(request_id, ConsentScope::Denied)
|
||||
.expect("resolution");
|
||||
assert_eq!(resolution.scope, ConsentScope::Denied);
|
||||
assert_eq!(resolution.tool_name, "resources/write");
|
||||
assert_eq!(resolution.tool_calls.len(), tool_calls.len());
|
||||
|
||||
let err = session
|
||||
.resolve_tool_consent(request_id, ConsentScope::Denied)
|
||||
.expect_err("second resolution should fail");
|
||||
matches!(err, Error::InvalidInput(_));
|
||||
|
||||
let conversation = session.conversation().clone();
|
||||
let assistant = conversation
|
||||
.messages
|
||||
.iter()
|
||||
.find(|message| message.role == Role::Assistant)
|
||||
.expect("assistant message present");
|
||||
assert!(
|
||||
assistant
|
||||
.tool_calls
|
||||
.as_ref()
|
||||
.and_then(|calls| calls.first())
|
||||
.is_some_and(|call| call.name == "resources/write"),
|
||||
"stream chunk should capture the tool call on the assistant message"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn web_tool_timeout_fails_over_to_cached_result() {
|
||||
let primary: Arc<dyn McpClient> = Arc::new(TimeoutClient);
|
||||
let cached = CachedResponseClient::new();
|
||||
let backup: Arc<dyn McpClient> = Arc::new(cached.clone());
|
||||
|
||||
let client = FailoverMcpClient::with_servers(vec![
|
||||
ServerEntry::new("primary".into(), primary, 1),
|
||||
ServerEntry::new("cache".into(), backup, 2),
|
||||
]);
|
||||
|
||||
let call = McpToolCall {
|
||||
name: "web_search".to_string(),
|
||||
arguments: serde_json::json!({ "query": "rust", "max_results": 3 }),
|
||||
};
|
||||
|
||||
let response = client.call_tool(call.clone()).await.expect("fallback");
|
||||
|
||||
assert_eq!(response.name, "web_search");
|
||||
assert_eq!(
|
||||
response.metadata.get("source").map(String::as_str),
|
||||
Some("cache")
|
||||
);
|
||||
assert_eq!(
|
||||
response.output.get("note").and_then(|value| value.as_str()),
|
||||
Some("cached result")
|
||||
);
|
||||
|
||||
let statuses = client.get_server_status().await;
|
||||
assert!(statuses.iter().any(|(name, health)| name == "primary"
|
||||
&& !matches!(health, owlen_core::mcp::failover::ServerHealth::Healthy)));
|
||||
assert!(statuses.iter().any(|(name, health)| name == "cache"
|
||||
&& matches!(health, owlen_core::mcp::failover::ServerHealth::Healthy)));
|
||||
}
|
||||
@@ -44,6 +44,7 @@ async fn test_render_prompt_via_external_server() -> Result<()> {
|
||||
args: Vec::new(),
|
||||
transport: "stdio".into(),
|
||||
env: std::collections::HashMap::new(),
|
||||
oauth: None,
|
||||
};
|
||||
|
||||
let client = match RemoteMcpClient::new_with_config(&config) {
|
||||
|
||||
10
crates/owlen-markdown/Cargo.toml
Normal file
10
crates/owlen-markdown/Cargo.toml
Normal file
@@ -0,0 +1,10 @@
|
||||
[package]
|
||||
name = "owlen-markdown"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "Lightweight markdown to ratatui::Text renderer for OWLEN"
|
||||
|
||||
[dependencies]
|
||||
ratatui = { workspace = true }
|
||||
unicode-width = "0.1"
|
||||
270
crates/owlen-markdown/src/lib.rs
Normal file
270
crates/owlen-markdown/src/lib.rs
Normal file
@@ -0,0 +1,270 @@
|
||||
use ratatui::prelude::*;
|
||||
use ratatui::text::{Line, Span, Text};
|
||||
use unicode_width::UnicodeWidthStr;
|
||||
|
||||
/// Convert a markdown string into a `ratatui::Text`.
|
||||
///
|
||||
/// This lightweight renderer supports common constructs (headings, lists, bold,
|
||||
/// italics, and inline code) and is designed to keep dependencies minimal for
|
||||
/// the OWLEN project.
|
||||
pub fn from_str(input: &str) -> Text<'static> {
|
||||
let mut lines = Vec::new();
|
||||
let mut in_code_block = false;
|
||||
|
||||
for raw_line in input.lines() {
|
||||
let line = raw_line.trim_end_matches('\r');
|
||||
let trimmed = line.trim_start();
|
||||
let indent = &line[..line.len() - trimmed.len()];
|
||||
|
||||
if trimmed.starts_with("```") {
|
||||
in_code_block = !in_code_block;
|
||||
continue;
|
||||
}
|
||||
|
||||
if in_code_block {
|
||||
let mut spans = Vec::new();
|
||||
if !indent.is_empty() {
|
||||
spans.push(Span::raw(indent.to_string()));
|
||||
}
|
||||
spans.push(Span::styled(
|
||||
trimmed.to_string(),
|
||||
Style::default()
|
||||
.fg(Color::LightYellow)
|
||||
.add_modifier(Modifier::DIM),
|
||||
));
|
||||
lines.push(Line::from(spans));
|
||||
continue;
|
||||
}
|
||||
|
||||
if trimmed.is_empty() {
|
||||
lines.push(Line::from(Vec::<Span<'static>>::new()));
|
||||
continue;
|
||||
}
|
||||
|
||||
if trimmed.starts_with('#') {
|
||||
let level = trimmed.chars().take_while(|c| *c == '#').count().min(6);
|
||||
let content = trimmed[level..].trim_start();
|
||||
let mut style = Style::default().add_modifier(Modifier::BOLD);
|
||||
style = match level {
|
||||
1 => style.fg(Color::LightCyan),
|
||||
2 => style.fg(Color::Cyan),
|
||||
_ => style.fg(Color::LightBlue),
|
||||
};
|
||||
let mut spans = Vec::new();
|
||||
if !indent.is_empty() {
|
||||
spans.push(Span::raw(indent.to_string()));
|
||||
}
|
||||
spans.push(Span::styled(content.to_string(), style));
|
||||
lines.push(Line::from(spans));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(rest) = trimmed.strip_prefix("- ") {
|
||||
let mut spans = Vec::new();
|
||||
if !indent.is_empty() {
|
||||
spans.push(Span::raw(indent.to_string()));
|
||||
}
|
||||
spans.push(Span::styled(
|
||||
"• ".to_string(),
|
||||
Style::default().fg(Color::LightGreen),
|
||||
));
|
||||
spans.extend(parse_inline(rest));
|
||||
lines.push(Line::from(spans));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(rest) = trimmed.strip_prefix("* ") {
|
||||
let mut spans = Vec::new();
|
||||
if !indent.is_empty() {
|
||||
spans.push(Span::raw(indent.to_string()));
|
||||
}
|
||||
spans.push(Span::styled(
|
||||
"• ".to_string(),
|
||||
Style::default().fg(Color::LightGreen),
|
||||
));
|
||||
spans.extend(parse_inline(rest));
|
||||
lines.push(Line::from(spans));
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some((number, rest)) = parse_ordered_item(trimmed) {
|
||||
let mut spans = Vec::new();
|
||||
if !indent.is_empty() {
|
||||
spans.push(Span::raw(indent.to_string()));
|
||||
}
|
||||
spans.push(Span::styled(
|
||||
format!("{number}. "),
|
||||
Style::default().fg(Color::LightGreen),
|
||||
));
|
||||
spans.extend(parse_inline(rest));
|
||||
lines.push(Line::from(spans));
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut spans = Vec::new();
|
||||
if !indent.is_empty() {
|
||||
spans.push(Span::raw(indent.to_string()));
|
||||
}
|
||||
spans.extend(parse_inline(trimmed));
|
||||
lines.push(Line::from(spans));
|
||||
}
|
||||
|
||||
if input.is_empty() {
|
||||
lines.push(Line::from(Vec::<Span<'static>>::new()));
|
||||
}
|
||||
|
||||
Text::from(lines)
|
||||
}
|
||||
|
||||
fn parse_ordered_item(line: &str) -> Option<(u32, &str)> {
|
||||
let mut parts = line.splitn(2, '.');
|
||||
let number = parts.next()?.trim();
|
||||
let rest = parts.next()?;
|
||||
if number.chars().all(|c| c.is_ascii_digit()) {
|
||||
let value = number.parse().ok()?;
|
||||
let rest = rest.trim_start();
|
||||
Some((value, rest))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_inline(text: &str) -> Vec<Span<'static>> {
|
||||
let mut spans = Vec::new();
|
||||
let bytes = text.as_bytes();
|
||||
let mut i = 0;
|
||||
let len = bytes.len();
|
||||
let mut plain_start = 0;
|
||||
|
||||
while i < len {
|
||||
if bytes[i] == b'`' {
|
||||
if let Some(offset) = text[i + 1..].find('`') {
|
||||
if i > plain_start {
|
||||
spans.push(Span::raw(text[plain_start..i].to_string()));
|
||||
}
|
||||
let content = &text[i + 1..i + 1 + offset];
|
||||
spans.push(Span::styled(
|
||||
content.to_string(),
|
||||
Style::default()
|
||||
.fg(Color::LightYellow)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
));
|
||||
i += offset + 2;
|
||||
plain_start = i;
|
||||
continue;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if bytes[i] == b'*' {
|
||||
if i + 1 < len && bytes[i + 1] == b'*' {
|
||||
if let Some(offset) = text[i + 2..].find("**") {
|
||||
if i > plain_start {
|
||||
spans.push(Span::raw(text[plain_start..i].to_string()));
|
||||
}
|
||||
let content = &text[i + 2..i + 2 + offset];
|
||||
spans.push(Span::styled(
|
||||
content.to_string(),
|
||||
Style::default().add_modifier(Modifier::BOLD),
|
||||
));
|
||||
i += offset + 4;
|
||||
plain_start = i;
|
||||
continue;
|
||||
}
|
||||
} else if let Some(offset) = text[i + 1..].find('*') {
|
||||
if i > plain_start {
|
||||
spans.push(Span::raw(text[plain_start..i].to_string()));
|
||||
}
|
||||
let content = &text[i + 1..i + 1 + offset];
|
||||
spans.push(Span::styled(
|
||||
content.to_string(),
|
||||
Style::default().add_modifier(Modifier::ITALIC),
|
||||
));
|
||||
i += offset + 2;
|
||||
plain_start = i;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if bytes[i] == b'_' {
|
||||
if i + 1 < len && bytes[i + 1] == b'_' {
|
||||
if let Some(offset) = text[i + 2..].find("__") {
|
||||
if i > plain_start {
|
||||
spans.push(Span::raw(text[plain_start..i].to_string()));
|
||||
}
|
||||
let content = &text[i + 2..i + 2 + offset];
|
||||
spans.push(Span::styled(
|
||||
content.to_string(),
|
||||
Style::default().add_modifier(Modifier::BOLD),
|
||||
));
|
||||
i += offset + 4;
|
||||
plain_start = i;
|
||||
continue;
|
||||
}
|
||||
} else if let Some(offset) = text[i + 1..].find('_') {
|
||||
if i > plain_start {
|
||||
spans.push(Span::raw(text[plain_start..i].to_string()));
|
||||
}
|
||||
let content = &text[i + 1..i + 1 + offset];
|
||||
spans.push(Span::styled(
|
||||
content.to_string(),
|
||||
Style::default().add_modifier(Modifier::ITALIC),
|
||||
));
|
||||
i += offset + 2;
|
||||
plain_start = i;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
i += 1;
|
||||
}
|
||||
|
||||
if plain_start < len {
|
||||
spans.push(Span::raw(text[plain_start..].to_string()));
|
||||
}
|
||||
|
||||
if spans.is_empty() {
|
||||
spans.push(Span::raw(String::new()));
|
||||
}
|
||||
|
||||
spans
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn visual_length(spans: &[Span<'_>]) -> usize {
|
||||
spans
|
||||
.iter()
|
||||
.map(|span| UnicodeWidthStr::width(span.content.as_ref()))
|
||||
.sum()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn headings_are_bold() {
|
||||
let text = from_str("# Heading");
|
||||
assert_eq!(text.lines.len(), 1);
|
||||
let line = &text.lines[0];
|
||||
assert!(
|
||||
line.spans
|
||||
.iter()
|
||||
.any(|span| span.style.contains(Modifier::BOLD))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn inline_code_styled() {
|
||||
let text = from_str("Use `code` inline.");
|
||||
let styled = text
|
||||
.lines
|
||||
.iter()
|
||||
.flat_map(|line| &line.spans)
|
||||
.find(|span| span.content.as_ref() == "code")
|
||||
.cloned()
|
||||
.unwrap();
|
||||
assert!(styled.style.contains(Modifier::BOLD));
|
||||
}
|
||||
}
|
||||
20
crates/owlen-providers/Cargo.toml
Normal file
20
crates/owlen-providers/Cargo.toml
Normal file
@@ -0,0 +1,20 @@
|
||||
[package]
|
||||
name = "owlen-providers"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
homepage.workspace = true
|
||||
description = "Provider implementations for OWLEN"
|
||||
|
||||
[dependencies]
|
||||
owlen-core = { path = "../owlen-core" }
|
||||
anyhow = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
reqwest = { package = "reqwest", version = "0.11", features = ["json", "stream"] }
|
||||
3
crates/owlen-providers/src/lib.rs
Normal file
3
crates/owlen-providers/src/lib.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
//! Provider implementations for OWLEN.
|
||||
|
||||
pub mod ollama;
|
||||
108
crates/owlen-providers/src/ollama/cloud.rs
Normal file
108
crates/owlen-providers/src/ollama/cloud.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
use std::{env, time::Duration};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use owlen_core::{
|
||||
Error as CoreError, Result as CoreResult,
|
||||
config::OLLAMA_CLOUD_BASE_URL,
|
||||
provider::{
|
||||
GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderMetadata,
|
||||
ProviderStatus, ProviderType,
|
||||
},
|
||||
};
|
||||
use serde_json::{Number, Value};
|
||||
|
||||
use super::OllamaClient;
|
||||
|
||||
const API_KEY_ENV: &str = "OLLAMA_CLOUD_API_KEY";
|
||||
|
||||
/// ModelProvider implementation for the hosted Ollama Cloud service.
|
||||
pub struct OllamaCloudProvider {
|
||||
client: OllamaClient,
|
||||
}
|
||||
|
||||
impl OllamaCloudProvider {
|
||||
/// Construct a new cloud provider. An API key must be supplied either
|
||||
/// directly or via the `OLLAMA_CLOUD_API_KEY` environment variable.
|
||||
pub fn new(
|
||||
base_url: Option<String>,
|
||||
api_key: Option<String>,
|
||||
request_timeout: Option<Duration>,
|
||||
) -> CoreResult<Self> {
|
||||
let (api_key, key_source) = resolve_api_key(api_key)?;
|
||||
let base_url = base_url.unwrap_or_else(|| OLLAMA_CLOUD_BASE_URL.to_string());
|
||||
|
||||
let mut metadata =
|
||||
ProviderMetadata::new("ollama_cloud", "Ollama (Cloud)", ProviderType::Cloud, true);
|
||||
metadata
|
||||
.metadata
|
||||
.insert("base_url".into(), Value::String(base_url.clone()));
|
||||
metadata.metadata.insert(
|
||||
"api_key_source".into(),
|
||||
Value::String(key_source.to_string()),
|
||||
);
|
||||
metadata
|
||||
.metadata
|
||||
.insert("api_key_env".into(), Value::String(API_KEY_ENV.to_string()));
|
||||
|
||||
if let Some(timeout) = request_timeout {
|
||||
let timeout_ms = timeout.as_millis().min(u128::from(u64::MAX)) as u64;
|
||||
metadata.metadata.insert(
|
||||
"request_timeout_ms".into(),
|
||||
Value::Number(Number::from(timeout_ms)),
|
||||
);
|
||||
}
|
||||
|
||||
let client = OllamaClient::new(&base_url, Some(api_key), metadata, request_timeout)?;
|
||||
|
||||
Ok(Self { client })
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ModelProvider for OllamaCloudProvider {
|
||||
fn metadata(&self) -> &ProviderMetadata {
|
||||
self.client.metadata()
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> CoreResult<ProviderStatus> {
|
||||
match self.client.health_check().await {
|
||||
Ok(status) => Ok(status),
|
||||
Err(CoreError::Auth(_)) => Ok(ProviderStatus::RequiresSetup),
|
||||
Err(err) => Err(err),
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> CoreResult<Vec<ModelInfo>> {
|
||||
self.client.list_models().await
|
||||
}
|
||||
|
||||
async fn generate_stream(&self, request: GenerateRequest) -> CoreResult<GenerateStream> {
|
||||
self.client.generate_stream(request).await
|
||||
}
|
||||
}
|
||||
|
||||
fn resolve_api_key(api_key: Option<String>) -> CoreResult<(String, &'static str)> {
|
||||
let key_from_config = api_key
|
||||
.as_ref()
|
||||
.map(|value| value.trim())
|
||||
.filter(|value| !value.is_empty())
|
||||
.map(str::to_string);
|
||||
|
||||
if let Some(key) = key_from_config {
|
||||
return Ok((key, "config"));
|
||||
}
|
||||
|
||||
let key_from_env = env::var(API_KEY_ENV)
|
||||
.ok()
|
||||
.map(|value| value.trim().to_string())
|
||||
.filter(|value| !value.is_empty());
|
||||
|
||||
if let Some(key) = key_from_env {
|
||||
return Ok((key, "env"));
|
||||
}
|
||||
|
||||
Err(CoreError::Config(
|
||||
"Ollama Cloud API key not configured. Set OLLAMA_CLOUD_API_KEY or configure an API key."
|
||||
.into(),
|
||||
))
|
||||
}
|
||||
80
crates/owlen-providers/src/ollama/local.rs
Normal file
80
crates/owlen-providers/src/ollama/local.rs
Normal file
@@ -0,0 +1,80 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use owlen_core::provider::{
|
||||
GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderMetadata, ProviderStatus,
|
||||
ProviderType,
|
||||
};
|
||||
use owlen_core::{Error as CoreError, Result as CoreResult};
|
||||
use serde_json::{Number, Value};
|
||||
use tokio::time::timeout;
|
||||
|
||||
use super::OllamaClient;
|
||||
|
||||
const DEFAULT_BASE_URL: &str = "http://localhost:11434";
|
||||
const DEFAULT_HEALTH_TIMEOUT_SECS: u64 = 5;
|
||||
|
||||
/// ModelProvider implementation for a local Ollama daemon.
|
||||
pub struct OllamaLocalProvider {
|
||||
client: OllamaClient,
|
||||
health_timeout: Duration,
|
||||
}
|
||||
|
||||
impl OllamaLocalProvider {
|
||||
/// Construct a new local provider using the shared [`OllamaClient`].
|
||||
pub fn new(
|
||||
base_url: Option<String>,
|
||||
request_timeout: Option<Duration>,
|
||||
health_timeout: Option<Duration>,
|
||||
) -> CoreResult<Self> {
|
||||
let base_url = base_url.unwrap_or_else(|| DEFAULT_BASE_URL.to_string());
|
||||
let health_timeout =
|
||||
health_timeout.unwrap_or_else(|| Duration::from_secs(DEFAULT_HEALTH_TIMEOUT_SECS));
|
||||
|
||||
let mut metadata =
|
||||
ProviderMetadata::new("ollama_local", "Ollama (Local)", ProviderType::Local, false);
|
||||
metadata
|
||||
.metadata
|
||||
.insert("base_url".into(), Value::String(base_url.clone()));
|
||||
if let Some(timeout) = request_timeout {
|
||||
let timeout_ms = timeout.as_millis().min(u128::from(u64::MAX)) as u64;
|
||||
metadata.metadata.insert(
|
||||
"request_timeout_ms".into(),
|
||||
Value::Number(Number::from(timeout_ms)),
|
||||
);
|
||||
}
|
||||
|
||||
let client = OllamaClient::new(&base_url, None, metadata, request_timeout)?;
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
health_timeout,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ModelProvider for OllamaLocalProvider {
|
||||
fn metadata(&self) -> &ProviderMetadata {
|
||||
self.client.metadata()
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> CoreResult<ProviderStatus> {
|
||||
match timeout(self.health_timeout, self.client.health_check()).await {
|
||||
Ok(Ok(status)) => Ok(status),
|
||||
Ok(Err(CoreError::Network(_))) | Ok(Err(CoreError::Timeout(_))) => {
|
||||
Ok(ProviderStatus::Unavailable)
|
||||
}
|
||||
Ok(Err(err)) => Err(err),
|
||||
Err(_) => Ok(ProviderStatus::Unavailable),
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> CoreResult<Vec<ModelInfo>> {
|
||||
self.client.list_models().await
|
||||
}
|
||||
|
||||
async fn generate_stream(&self, request: GenerateRequest) -> CoreResult<GenerateStream> {
|
||||
self.client.generate_stream(request).await
|
||||
}
|
||||
}
|
||||
7
crates/owlen-providers/src/ollama/mod.rs
Normal file
7
crates/owlen-providers/src/ollama/mod.rs
Normal file
@@ -0,0 +1,7 @@
|
||||
pub mod cloud;
|
||||
pub mod local;
|
||||
pub mod shared;
|
||||
|
||||
pub use cloud::OllamaCloudProvider;
|
||||
pub use local::OllamaLocalProvider;
|
||||
pub use shared::OllamaClient;
|
||||
389
crates/owlen-providers/src/ollama/shared.rs
Normal file
389
crates/owlen-providers/src/ollama/shared.rs
Normal file
@@ -0,0 +1,389 @@
|
||||
use std::collections::HashMap;
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::StreamExt;
|
||||
use owlen_core::provider::{
|
||||
GenerateChunk, GenerateRequest, GenerateStream, ModelInfo, ProviderMetadata, ProviderStatus,
|
||||
};
|
||||
use owlen_core::{Error as CoreError, Result as CoreResult};
|
||||
use reqwest::{Client, Method, StatusCode, Url};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{Map as JsonMap, Value};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
|
||||
const DEFAULT_TIMEOUT_SECS: u64 = 60;
|
||||
|
||||
/// Shared Ollama HTTP client used by both local and cloud providers.
|
||||
#[derive(Clone)]
|
||||
pub struct OllamaClient {
|
||||
http: Client,
|
||||
base_url: Url,
|
||||
api_key: Option<String>,
|
||||
provider_metadata: ProviderMetadata,
|
||||
}
|
||||
|
||||
impl OllamaClient {
|
||||
/// Create a new client with the given base URL and optional API key.
|
||||
pub fn new(
|
||||
base_url: impl AsRef<str>,
|
||||
api_key: Option<String>,
|
||||
provider_metadata: ProviderMetadata,
|
||||
request_timeout: Option<Duration>,
|
||||
) -> CoreResult<Self> {
|
||||
let base_url = Url::parse(base_url.as_ref())
|
||||
.map_err(|err| CoreError::Config(format!("invalid base url: {}", err)))?;
|
||||
|
||||
let timeout = request_timeout.unwrap_or_else(|| Duration::from_secs(DEFAULT_TIMEOUT_SECS));
|
||||
let http = Client::builder()
|
||||
.timeout(timeout)
|
||||
.build()
|
||||
.map_err(map_reqwest_error)?;
|
||||
|
||||
Ok(Self {
|
||||
http,
|
||||
base_url,
|
||||
api_key,
|
||||
provider_metadata,
|
||||
})
|
||||
}
|
||||
|
||||
/// Provider metadata associated with this client.
|
||||
pub fn metadata(&self) -> &ProviderMetadata {
|
||||
&self.provider_metadata
|
||||
}
|
||||
|
||||
/// Perform a basic health check to determine provider availability.
|
||||
pub async fn health_check(&self) -> CoreResult<ProviderStatus> {
|
||||
let url = self.endpoint("api/tags")?;
|
||||
|
||||
let response = self
|
||||
.request(Method::GET, url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(map_reqwest_error)?;
|
||||
|
||||
match response.status() {
|
||||
status if status.is_success() => Ok(ProviderStatus::Available),
|
||||
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => Ok(ProviderStatus::RequiresSetup),
|
||||
_ => Ok(ProviderStatus::Unavailable),
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch the available models from the Ollama API.
|
||||
pub async fn list_models(&self) -> CoreResult<Vec<ModelInfo>> {
|
||||
let url = self.endpoint("api/tags")?;
|
||||
|
||||
let response = self
|
||||
.request(Method::GET, url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(map_reqwest_error)?;
|
||||
|
||||
let status = response.status();
|
||||
let bytes = response.bytes().await.map_err(map_reqwest_error)?;
|
||||
|
||||
if !status.is_success() {
|
||||
return Err(map_http_error("tags", status, &bytes));
|
||||
}
|
||||
|
||||
let payload: TagsResponse =
|
||||
serde_json::from_slice(&bytes).map_err(CoreError::Serialization)?;
|
||||
|
||||
let models = payload
|
||||
.models
|
||||
.into_iter()
|
||||
.map(|model| self.parse_model_info(model))
|
||||
.collect();
|
||||
|
||||
Ok(models)
|
||||
}
|
||||
|
||||
/// Request a streaming generation session from Ollama.
|
||||
pub async fn generate_stream(&self, request: GenerateRequest) -> CoreResult<GenerateStream> {
|
||||
let url = self.endpoint("api/generate")?;
|
||||
|
||||
let body = self.build_generate_body(request);
|
||||
|
||||
let response = self
|
||||
.request(Method::POST, url)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(map_reqwest_error)?;
|
||||
|
||||
let status = response.status();
|
||||
|
||||
if !status.is_success() {
|
||||
let bytes = response.bytes().await.map_err(map_reqwest_error)?;
|
||||
return Err(map_http_error("generate", status, &bytes));
|
||||
}
|
||||
|
||||
let stream = response.bytes_stream();
|
||||
let (tx, rx) = mpsc::channel::<CoreResult<GenerateChunk>>(32);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut stream = stream;
|
||||
let mut buffer: Vec<u8> = Vec::new();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
match chunk {
|
||||
Ok(bytes) => {
|
||||
buffer.extend_from_slice(&bytes);
|
||||
while let Some(pos) = buffer.iter().position(|byte| *byte == b'\n') {
|
||||
let line_bytes: Vec<u8> = buffer.drain(..=pos).collect();
|
||||
let line = String::from_utf8_lossy(&line_bytes).trim().to_string();
|
||||
if line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
match parse_stream_line(&line) {
|
||||
Ok(item) => {
|
||||
if tx.send(Ok(item)).await.is_err() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = tx.send(Err(err)).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = tx.send(Err(map_reqwest_error(err))).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if !buffer.is_empty() {
|
||||
let line = String::from_utf8_lossy(&buffer).trim().to_string();
|
||||
if !line.is_empty() {
|
||||
match parse_stream_line(&line) {
|
||||
Ok(item) => {
|
||||
let _ = tx.send(Ok(item)).await;
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = tx.send(Err(err)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let stream = ReceiverStream::new(rx);
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
|
||||
fn request(&self, method: Method, url: Url) -> reqwest::RequestBuilder {
|
||||
let mut builder = self.http.request(method, url);
|
||||
if let Some(api_key) = &self.api_key {
|
||||
builder = builder.bearer_auth(api_key);
|
||||
}
|
||||
builder
|
||||
}
|
||||
|
||||
fn endpoint(&self, path: &str) -> CoreResult<Url> {
|
||||
self.base_url
|
||||
.join(path)
|
||||
.map_err(|err| CoreError::Config(format!("invalid endpoint '{}': {}", path, err)))
|
||||
}
|
||||
|
||||
fn build_generate_body(&self, request: GenerateRequest) -> Value {
|
||||
let GenerateRequest {
|
||||
model,
|
||||
prompt,
|
||||
context,
|
||||
parameters,
|
||||
metadata,
|
||||
} = request;
|
||||
|
||||
let mut body = JsonMap::new();
|
||||
body.insert("model".into(), Value::String(model));
|
||||
body.insert("stream".into(), Value::Bool(true));
|
||||
|
||||
if let Some(prompt) = prompt {
|
||||
body.insert("prompt".into(), Value::String(prompt));
|
||||
}
|
||||
|
||||
if !context.is_empty() {
|
||||
let items = context.into_iter().map(Value::String).collect();
|
||||
body.insert("context".into(), Value::Array(items));
|
||||
}
|
||||
|
||||
if !parameters.is_empty() {
|
||||
body.insert("options".into(), Value::Object(to_json_map(parameters)));
|
||||
}
|
||||
|
||||
if !metadata.is_empty() {
|
||||
body.insert("metadata".into(), Value::Object(to_json_map(metadata)));
|
||||
}
|
||||
|
||||
Value::Object(body)
|
||||
}
|
||||
|
||||
fn parse_model_info(&self, model: OllamaModel) -> ModelInfo {
|
||||
let mut metadata = HashMap::new();
|
||||
|
||||
if let Some(digest) = model.digest {
|
||||
metadata.insert("digest".to_string(), Value::String(digest));
|
||||
}
|
||||
if let Some(modified) = model.modified_at {
|
||||
metadata.insert("modified_at".to_string(), Value::String(modified));
|
||||
}
|
||||
if let Some(details) = model.details {
|
||||
let mut details_map = JsonMap::new();
|
||||
if let Some(format) = details.format {
|
||||
details_map.insert("format".into(), Value::String(format));
|
||||
}
|
||||
if let Some(family) = details.family {
|
||||
details_map.insert("family".into(), Value::String(family));
|
||||
}
|
||||
if let Some(parameter_size) = details.parameter_size {
|
||||
details_map.insert("parameter_size".into(), Value::String(parameter_size));
|
||||
}
|
||||
if let Some(quantisation) = details.quantization_level {
|
||||
details_map.insert("quantization_level".into(), Value::String(quantisation));
|
||||
}
|
||||
|
||||
if !details_map.is_empty() {
|
||||
metadata.insert("details".to_string(), Value::Object(details_map));
|
||||
}
|
||||
}
|
||||
|
||||
ModelInfo {
|
||||
name: model.name,
|
||||
size_bytes: model.size,
|
||||
capabilities: Vec::new(),
|
||||
description: None,
|
||||
provider: self.provider_metadata.clone(),
|
||||
metadata,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TagsResponse {
|
||||
#[serde(default)]
|
||||
models: Vec<OllamaModel>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OllamaModel {
|
||||
name: String,
|
||||
#[serde(default)]
|
||||
size: Option<u64>,
|
||||
#[serde(default)]
|
||||
digest: Option<String>,
|
||||
#[serde(default)]
|
||||
modified_at: Option<String>,
|
||||
#[serde(default)]
|
||||
details: Option<OllamaModelDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OllamaModelDetails {
|
||||
#[serde(default)]
|
||||
format: Option<String>,
|
||||
#[serde(default)]
|
||||
family: Option<String>,
|
||||
#[serde(default)]
|
||||
parameter_size: Option<String>,
|
||||
#[serde(default)]
|
||||
quantization_level: Option<String>,
|
||||
}
|
||||
|
||||
fn to_json_map(source: HashMap<String, Value>) -> JsonMap<String, Value> {
|
||||
source.into_iter().collect()
|
||||
}
|
||||
|
||||
fn to_metadata_map(value: &Value) -> HashMap<String, Value> {
|
||||
let mut metadata = HashMap::new();
|
||||
|
||||
if let Value::Object(obj) = value {
|
||||
for (key, item) in obj {
|
||||
if key == "response" || key == "done" {
|
||||
continue;
|
||||
}
|
||||
metadata.insert(key.clone(), item.clone());
|
||||
}
|
||||
}
|
||||
|
||||
metadata
|
||||
}
|
||||
|
||||
fn parse_stream_line(line: &str) -> CoreResult<GenerateChunk> {
|
||||
let value: Value = serde_json::from_str(line).map_err(CoreError::Serialization)?;
|
||||
|
||||
if let Some(error) = value.get("error").and_then(Value::as_str) {
|
||||
return Err(CoreError::Provider(anyhow::anyhow!(
|
||||
"ollama generation error: {}",
|
||||
error
|
||||
)));
|
||||
}
|
||||
|
||||
let mut chunk = GenerateChunk {
|
||||
text: value
|
||||
.get("response")
|
||||
.and_then(Value::as_str)
|
||||
.map(str::to_string),
|
||||
is_final: value.get("done").and_then(Value::as_bool).unwrap_or(false),
|
||||
metadata: to_metadata_map(&value),
|
||||
};
|
||||
|
||||
if chunk.is_final && chunk.text.is_none() && chunk.metadata.is_empty() {
|
||||
chunk
|
||||
.metadata
|
||||
.insert("status".into(), Value::String("done".into()));
|
||||
}
|
||||
|
||||
Ok(chunk)
|
||||
}
|
||||
|
||||
fn map_http_error(endpoint: &str, status: StatusCode, body: &[u8]) -> CoreError {
|
||||
match status {
|
||||
StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => CoreError::Auth(format!(
|
||||
"Ollama {} request unauthorized (status {})",
|
||||
endpoint, status
|
||||
)),
|
||||
StatusCode::TOO_MANY_REQUESTS => CoreError::Provider(anyhow::anyhow!(
|
||||
"Ollama {} request rate limited (status {})",
|
||||
endpoint,
|
||||
status
|
||||
)),
|
||||
_ => {
|
||||
let snippet = truncated_body(body);
|
||||
CoreError::Provider(anyhow::anyhow!(
|
||||
"Ollama {} request failed: HTTP {} - {}",
|
||||
endpoint,
|
||||
status,
|
||||
snippet
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn truncated_body(body: &[u8]) -> String {
|
||||
const MAX_CHARS: usize = 512;
|
||||
let text = String::from_utf8_lossy(body);
|
||||
let mut value = String::new();
|
||||
for (idx, ch) in text.chars().enumerate() {
|
||||
if idx >= MAX_CHARS {
|
||||
value.push('…');
|
||||
return value;
|
||||
}
|
||||
value.push(ch);
|
||||
}
|
||||
value
|
||||
}
|
||||
|
||||
fn map_reqwest_error(err: reqwest::Error) -> CoreError {
|
||||
if err.is_timeout() {
|
||||
CoreError::Timeout(err.to_string())
|
||||
} else if err.is_connect() || err.is_request() {
|
||||
CoreError::Network(err.to_string())
|
||||
} else {
|
||||
CoreError::Provider(err.into())
|
||||
}
|
||||
}
|
||||
106
crates/owlen-providers/tests/common/mock_provider.rs
Normal file
106
crates/owlen-providers/tests/common/mock_provider.rs
Normal file
@@ -0,0 +1,106 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::stream::{self, StreamExt};
|
||||
use owlen_core::Result as CoreResult;
|
||||
use owlen_core::provider::{
|
||||
GenerateChunk, GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderMetadata,
|
||||
ProviderStatus, ProviderType,
|
||||
};
|
||||
|
||||
pub struct MockProvider {
|
||||
metadata: ProviderMetadata,
|
||||
models: Vec<ModelInfo>,
|
||||
status: ProviderStatus,
|
||||
#[allow(clippy::type_complexity)]
|
||||
generate_handler: Option<Arc<dyn Fn(GenerateRequest) -> Vec<GenerateChunk> + Send + Sync>>,
|
||||
generate_error: Option<Arc<dyn Fn() -> owlen_core::Error + Send + Sync>>,
|
||||
}
|
||||
|
||||
impl MockProvider {
|
||||
pub fn new(id: &str) -> Self {
|
||||
let metadata = ProviderMetadata::new(
|
||||
id,
|
||||
format!("Mock Provider ({})", id),
|
||||
ProviderType::Local,
|
||||
false,
|
||||
);
|
||||
|
||||
Self {
|
||||
metadata,
|
||||
models: vec![ModelInfo {
|
||||
name: format!("{}-primary", id),
|
||||
size_bytes: None,
|
||||
capabilities: vec!["chat".into()],
|
||||
description: Some("Mock model".into()),
|
||||
provider: ProviderMetadata::new(id, "Mock", ProviderType::Local, false),
|
||||
metadata: Default::default(),
|
||||
}],
|
||||
status: ProviderStatus::Available,
|
||||
generate_handler: None,
|
||||
generate_error: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_models(mut self, models: Vec<ModelInfo>) -> Self {
|
||||
self.models = models;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_status(mut self, status: ProviderStatus) -> Self {
|
||||
self.status = status;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_generate_handler<F>(mut self, handler: F) -> Self
|
||||
where
|
||||
F: Fn(GenerateRequest) -> Vec<GenerateChunk> + Send + Sync + 'static,
|
||||
{
|
||||
self.generate_handler = Some(Arc::new(handler));
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_generate_error<F>(mut self, factory: F) -> Self
|
||||
where
|
||||
F: Fn() -> owlen_core::Error + Send + Sync + 'static,
|
||||
{
|
||||
self.generate_error = Some(Arc::new(factory));
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ModelProvider for MockProvider {
|
||||
fn metadata(&self) -> &ProviderMetadata {
|
||||
&self.metadata
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> CoreResult<ProviderStatus> {
|
||||
Ok(self.status)
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> CoreResult<Vec<ModelInfo>> {
|
||||
Ok(self.models.clone())
|
||||
}
|
||||
|
||||
async fn generate_stream(&self, request: GenerateRequest) -> CoreResult<GenerateStream> {
|
||||
if let Some(factory) = &self.generate_error {
|
||||
return Err(factory());
|
||||
}
|
||||
|
||||
let chunks = if let Some(handler) = &self.generate_handler {
|
||||
(handler)(request)
|
||||
} else {
|
||||
vec![GenerateChunk::final_chunk()]
|
||||
};
|
||||
|
||||
let stream = stream::iter(chunks.into_iter().map(Ok)).boxed();
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
}
|
||||
|
||||
impl From<MockProvider> for Arc<dyn ModelProvider> {
|
||||
fn from(provider: MockProvider) -> Self {
|
||||
Arc::new(provider)
|
||||
}
|
||||
}
|
||||
1
crates/owlen-providers/tests/common/mod.rs
Normal file
1
crates/owlen-providers/tests/common/mod.rs
Normal file
@@ -0,0 +1 @@
|
||||
pub mod mock_provider;
|
||||
117
crates/owlen-providers/tests/integration_test.rs
Normal file
117
crates/owlen-providers/tests/integration_test.rs
Normal file
@@ -0,0 +1,117 @@
|
||||
mod common;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::StreamExt;
|
||||
|
||||
use common::mock_provider::MockProvider;
|
||||
use owlen_core::config::Config;
|
||||
use owlen_core::provider::{
|
||||
GenerateChunk, GenerateRequest, ModelInfo, ProviderManager, ProviderType,
|
||||
};
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn base_config() -> Config {
|
||||
Config {
|
||||
providers: Default::default(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
fn make_model(name: &str, provider: &str) -> ModelInfo {
|
||||
ModelInfo {
|
||||
name: name.into(),
|
||||
size_bytes: None,
|
||||
capabilities: vec!["chat".into()],
|
||||
description: Some("mock".into()),
|
||||
provider: owlen_core::provider::ProviderMetadata::new(
|
||||
provider,
|
||||
provider,
|
||||
ProviderType::Local,
|
||||
false,
|
||||
),
|
||||
metadata: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn registers_providers_and_lists_ids() {
|
||||
let manager = ProviderManager::default();
|
||||
let provider: Arc<dyn owlen_core::provider::ModelProvider> = MockProvider::new("mock-a").into();
|
||||
|
||||
manager.register_provider(provider).await;
|
||||
let ids = manager.provider_ids().await;
|
||||
|
||||
assert_eq!(ids, vec!["mock-a".to_string()]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn aggregates_models_across_providers() {
|
||||
let manager = ProviderManager::default();
|
||||
let provider_a = MockProvider::new("mock-a").with_models(vec![make_model("alpha", "mock-a")]);
|
||||
let provider_b = MockProvider::new("mock-b").with_models(vec![make_model("beta", "mock-b")]);
|
||||
|
||||
manager.register_provider(provider_a.into()).await;
|
||||
manager.register_provider(provider_b.into()).await;
|
||||
|
||||
let models = manager.list_all_models().await.unwrap();
|
||||
assert_eq!(models.len(), 2);
|
||||
assert!(models.iter().any(|m| m.model.name == "alpha"));
|
||||
assert!(models.iter().any(|m| m.model.name == "beta"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn routes_generation_to_specific_provider() {
|
||||
let manager = ProviderManager::default();
|
||||
let provider = MockProvider::new("mock-gen").with_generate_handler(|_req| {
|
||||
vec![
|
||||
GenerateChunk::from_text("hello"),
|
||||
GenerateChunk::final_chunk(),
|
||||
]
|
||||
});
|
||||
|
||||
manager.register_provider(provider.into()).await;
|
||||
|
||||
let request = GenerateRequest::new("mock-gen::primary");
|
||||
let mut stream = manager.generate("mock-gen", request).await.unwrap();
|
||||
let mut collected = Vec::new();
|
||||
while let Some(chunk) = stream.next().await {
|
||||
collected.push(chunk.unwrap());
|
||||
}
|
||||
|
||||
assert_eq!(collected.len(), 2);
|
||||
assert_eq!(collected[0].text.as_deref(), Some("hello"));
|
||||
assert!(collected[1].is_final);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn marks_provider_unavailable_on_error() {
|
||||
let manager = ProviderManager::default();
|
||||
let provider = MockProvider::new("flaky")
|
||||
.with_generate_error(|| owlen_core::Error::Network("boom".into()));
|
||||
|
||||
manager.register_provider(provider.into()).await;
|
||||
let request = GenerateRequest::new("flaky::model");
|
||||
let result = manager.generate("flaky", request).await;
|
||||
assert!(result.is_err());
|
||||
|
||||
let status = manager.provider_status("flaky").await.unwrap();
|
||||
assert!(matches!(
|
||||
status,
|
||||
owlen_core::provider::ProviderStatus::Unavailable
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn health_refresh_updates_status_cache() {
|
||||
let manager = ProviderManager::default();
|
||||
let provider =
|
||||
MockProvider::new("healthy").with_status(owlen_core::provider::ProviderStatus::Available);
|
||||
|
||||
manager.register_provider(provider.into()).await;
|
||||
let statuses = manager.refresh_health().await;
|
||||
assert_eq!(
|
||||
statuses.get("healthy"),
|
||||
Some(&owlen_core::provider::ProviderStatus::Available)
|
||||
);
|
||||
}
|
||||
@@ -20,6 +20,18 @@ textwrap = { workspace = true }
|
||||
unicode-width = "0.1"
|
||||
unicode-segmentation = "1.11"
|
||||
async-trait = "0.1"
|
||||
globset = "0.4"
|
||||
ignore = "0.4"
|
||||
pathdiff = "0.2"
|
||||
tree-sitter = "0.20"
|
||||
tree-sitter-rust = "0.20"
|
||||
dirs = { workspace = true }
|
||||
toml = { workspace = true }
|
||||
syntect = "5.3"
|
||||
once_cell = "1.19"
|
||||
owlen-markdown = { path = "../owlen-markdown" }
|
||||
shellexpand = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
|
||||
# Async runtime
|
||||
tokio = { workspace = true }
|
||||
@@ -30,6 +42,9 @@ futures-util = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
serde_json.workspace = true
|
||||
serde.workspace = true
|
||||
chrono = { workspace = true }
|
||||
log = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = { workspace = true }
|
||||
|
||||
99
crates/owlen-tui/keymap.toml
Normal file
99
crates/owlen-tui/keymap.toml
Normal file
@@ -0,0 +1,99 @@
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["m"]
|
||||
command = "model.open_all"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Ctrl+Shift+L"]
|
||||
command = "model.open_local"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Ctrl+Shift+C"]
|
||||
command = "model.open_cloud"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Ctrl+Shift+P"]
|
||||
command = "model.open_available"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Ctrl+P"]
|
||||
command = "palette.open"
|
||||
|
||||
[[binding]]
|
||||
mode = "editing"
|
||||
keys = ["Ctrl+P"]
|
||||
command = "palette.open"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Tab"]
|
||||
command = "focus.next"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Shift+Tab"]
|
||||
command = "focus.prev"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Ctrl+1"]
|
||||
command = "focus.files"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Ctrl+2"]
|
||||
command = "focus.chat"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Ctrl+3"]
|
||||
command = "focus.code"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Ctrl+4"]
|
||||
command = "focus.thinking"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Ctrl+5"]
|
||||
command = "focus.input"
|
||||
|
||||
[[binding]]
|
||||
mode = "editing"
|
||||
keys = ["Enter"]
|
||||
command = "composer.submit"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["Ctrl+;"]
|
||||
command = "mode.command"
|
||||
|
||||
[[binding]]
|
||||
mode = "normal"
|
||||
keys = ["F12"]
|
||||
command = "debug.toggle"
|
||||
|
||||
[[binding]]
|
||||
mode = "editing"
|
||||
keys = ["F12"]
|
||||
command = "debug.toggle"
|
||||
|
||||
[[binding]]
|
||||
mode = "visual"
|
||||
keys = ["F12"]
|
||||
command = "debug.toggle"
|
||||
|
||||
[[binding]]
|
||||
mode = "command"
|
||||
keys = ["F12"]
|
||||
command = "debug.toggle"
|
||||
|
||||
[[binding]]
|
||||
mode = "help"
|
||||
keys = ["F12"]
|
||||
command = "debug.toggle"
|
||||
77
crates/owlen-tui/src/app/generation.rs
Normal file
77
crates/owlen-tui/src/app/generation.rs
Normal file
@@ -0,0 +1,77 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use futures_util::StreamExt;
|
||||
use owlen_core::provider::GenerateRequest;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::{ActiveGeneration, App, AppMessage};
|
||||
|
||||
impl App {
|
||||
/// Kick off a new generation task on the supplied provider.
|
||||
pub fn start_generation(
|
||||
&mut self,
|
||||
provider_id: impl Into<String>,
|
||||
request: GenerateRequest,
|
||||
) -> Result<Uuid> {
|
||||
let provider_id = provider_id.into();
|
||||
let request_id = Uuid::new_v4();
|
||||
|
||||
// Cancel any existing task so we don't interleave output.
|
||||
if let Some(active) = self.active_generation.take() {
|
||||
active.abort();
|
||||
}
|
||||
|
||||
self.message_tx
|
||||
.send(AppMessage::GenerateStart {
|
||||
request_id,
|
||||
provider_id: provider_id.clone(),
|
||||
request: request.clone(),
|
||||
})
|
||||
.map_err(|err| anyhow!("failed to queue generation start: {err:?}"))?;
|
||||
|
||||
let manager = Arc::clone(&self.provider_manager);
|
||||
let message_tx = self.message_tx.clone();
|
||||
let provider_for_task = provider_id.clone();
|
||||
|
||||
let join_handle = tokio::spawn(async move {
|
||||
let mut stream = match manager.generate(&provider_for_task, request).await {
|
||||
Ok(stream) => stream,
|
||||
Err(err) => {
|
||||
let _ = message_tx.send(AppMessage::GenerateError {
|
||||
request_id: Some(request_id),
|
||||
message: err.to_string(),
|
||||
});
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
match chunk_result {
|
||||
Ok(chunk) => {
|
||||
if message_tx
|
||||
.send(AppMessage::GenerateChunk { request_id, chunk })
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = message_tx.send(AppMessage::GenerateError {
|
||||
request_id: Some(request_id),
|
||||
message: err.to_string(),
|
||||
});
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = message_tx.send(AppMessage::GenerateComplete { request_id });
|
||||
});
|
||||
|
||||
let generation = ActiveGeneration::new(request_id, provider_id, join_handle);
|
||||
self.active_generation = Some(generation);
|
||||
|
||||
Ok(request_id)
|
||||
}
|
||||
}
|
||||
135
crates/owlen-tui/src/app/handler.rs
Normal file
135
crates/owlen-tui/src/app/handler.rs
Normal file
@@ -0,0 +1,135 @@
|
||||
use super::{App, messages::AppMessage};
|
||||
use log::warn;
|
||||
use owlen_core::{
|
||||
provider::{GenerateChunk, GenerateRequest, ProviderStatus},
|
||||
state::AppState,
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Trait implemented by UI state containers to react to [`AppMessage`] events.
|
||||
pub trait MessageState {
|
||||
/// Called when a generation request is about to start.
|
||||
#[allow(unused_variables)]
|
||||
fn start_generation(
|
||||
&mut self,
|
||||
request_id: Uuid,
|
||||
provider_id: &str,
|
||||
request: &GenerateRequest,
|
||||
) -> AppState {
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
/// Called for every streamed generation chunk.
|
||||
#[allow(unused_variables)]
|
||||
fn append_chunk(&mut self, request_id: Uuid, chunk: &GenerateChunk) -> AppState {
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
/// Called when a generation finishes successfully.
|
||||
#[allow(unused_variables)]
|
||||
fn generation_complete(&mut self, request_id: Uuid) -> AppState {
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
/// Called when a generation fails.
|
||||
#[allow(unused_variables)]
|
||||
fn generation_failed(&mut self, request_id: Option<Uuid>, message: &str) -> AppState {
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
/// Called when refreshed model metadata is available.
|
||||
fn update_model_list(&mut self) -> AppState {
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
/// Called when a models refresh has been requested.
|
||||
fn refresh_model_list(&mut self) -> AppState {
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
/// Called when provider status updates arrive.
|
||||
#[allow(unused_variables)]
|
||||
fn update_provider_status(&mut self, provider_id: &str, status: ProviderStatus) -> AppState {
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
/// Called when a resize event occurs.
|
||||
#[allow(unused_variables)]
|
||||
fn handle_resize(&mut self, width: u16, height: u16) -> AppState {
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
/// Called on periodic ticks.
|
||||
fn handle_tick(&mut self) -> AppState {
|
||||
AppState::Running
|
||||
}
|
||||
}
|
||||
|
||||
impl App {
|
||||
/// Dispatch a message to the provided [`MessageState`]. Returns `true` when the
|
||||
/// state indicates the UI should exit.
|
||||
pub fn handle_message<State>(&mut self, state: &mut State, message: AppMessage) -> bool
|
||||
where
|
||||
State: MessageState,
|
||||
{
|
||||
use AppMessage::*;
|
||||
|
||||
let outcome = match message {
|
||||
KeyPress(_) => AppState::Running,
|
||||
Resize { width, height } => state.handle_resize(width, height),
|
||||
Tick => state.handle_tick(),
|
||||
GenerateStart {
|
||||
request_id,
|
||||
provider_id,
|
||||
request,
|
||||
} => state.start_generation(request_id, &provider_id, &request),
|
||||
GenerateChunk { request_id, chunk } => state.append_chunk(request_id, &chunk),
|
||||
GenerateComplete { request_id } => {
|
||||
self.clear_active_generation(request_id);
|
||||
state.generation_complete(request_id)
|
||||
}
|
||||
GenerateError {
|
||||
request_id,
|
||||
message,
|
||||
} => {
|
||||
self.clear_active_generation_optional(request_id);
|
||||
state.generation_failed(request_id, &message)
|
||||
}
|
||||
ModelsRefresh => state.refresh_model_list(),
|
||||
ModelsUpdated => state.update_model_list(),
|
||||
ProviderStatus {
|
||||
provider_id,
|
||||
status,
|
||||
} => state.update_provider_status(&provider_id, status),
|
||||
};
|
||||
|
||||
matches!(outcome, AppState::Quit)
|
||||
}
|
||||
|
||||
fn clear_active_generation(&mut self, request_id: Uuid) {
|
||||
if self
|
||||
.active_generation
|
||||
.as_ref()
|
||||
.map(|active| active.request_id() == request_id)
|
||||
.unwrap_or(false)
|
||||
{
|
||||
self.active_generation = None;
|
||||
} else {
|
||||
warn!(
|
||||
"received completion for unknown request {}, ignoring",
|
||||
request_id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn clear_active_generation_optional(&mut self, request_id: Option<Uuid>) {
|
||||
match request_id {
|
||||
Some(id) => self.clear_active_generation(id),
|
||||
None => {
|
||||
if self.active_generation.is_some() {
|
||||
self.active_generation = None;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
41
crates/owlen-tui/src/app/messages.rs
Normal file
41
crates/owlen-tui/src/app/messages.rs
Normal file
@@ -0,0 +1,41 @@
|
||||
use crossterm::event::KeyEvent;
|
||||
use owlen_core::provider::{GenerateChunk, GenerateRequest, ProviderStatus};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Messages exchanged between the UI event loop and background workers.
|
||||
#[derive(Debug)]
|
||||
pub enum AppMessage {
|
||||
/// User input event bubbled up from the terminal layer.
|
||||
KeyPress(KeyEvent),
|
||||
/// Terminal resize notification.
|
||||
Resize { width: u16, height: u16 },
|
||||
/// Periodic tick used to drive animations.
|
||||
Tick,
|
||||
/// Initiate a new text generation request.
|
||||
GenerateStart {
|
||||
request_id: Uuid,
|
||||
provider_id: String,
|
||||
request: GenerateRequest,
|
||||
},
|
||||
/// Streamed response chunk from the active generation task.
|
||||
GenerateChunk {
|
||||
request_id: Uuid,
|
||||
chunk: GenerateChunk,
|
||||
},
|
||||
/// Generation finished successfully.
|
||||
GenerateComplete { request_id: Uuid },
|
||||
/// Generation failed or was aborted.
|
||||
GenerateError {
|
||||
request_id: Option<Uuid>,
|
||||
message: String,
|
||||
},
|
||||
/// Trigger a background refresh of available models.
|
||||
ModelsRefresh,
|
||||
/// New model list data is ready.
|
||||
ModelsUpdated,
|
||||
/// Provider health status update.
|
||||
ProviderStatus {
|
||||
provider_id: String,
|
||||
status: ProviderStatus,
|
||||
},
|
||||
}
|
||||
240
crates/owlen-tui/src/app/mod.rs
Normal file
240
crates/owlen-tui/src/app/mod.rs
Normal file
@@ -0,0 +1,240 @@
|
||||
mod generation;
|
||||
mod handler;
|
||||
pub mod mvu;
|
||||
mod worker;
|
||||
|
||||
pub mod messages;
|
||||
pub use worker::background_worker;
|
||||
|
||||
use std::{
|
||||
io,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use crossterm::event::{self, KeyEventKind};
|
||||
use owlen_core::{provider::ProviderManager, state::AppState};
|
||||
use ratatui::{Terminal, backend::CrosstermBackend};
|
||||
use tokio::{
|
||||
sync::mpsc::{self, error::TryRecvError},
|
||||
task::{AbortHandle, JoinHandle, yield_now},
|
||||
};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{Event, SessionEvent, events};
|
||||
|
||||
pub use handler::MessageState;
|
||||
pub use messages::AppMessage;
|
||||
|
||||
#[async_trait]
|
||||
pub trait UiRuntime: MessageState {
|
||||
async fn handle_ui_event(&mut self, event: Event) -> Result<AppState>;
|
||||
async fn handle_session_event(&mut self, event: SessionEvent) -> Result<()>;
|
||||
async fn process_pending_llm_request(&mut self) -> Result<()>;
|
||||
async fn process_pending_tool_execution(&mut self) -> Result<()>;
|
||||
fn poll_controller_events(&mut self) -> Result<()>;
|
||||
fn advance_loading_animation(&mut self);
|
||||
fn streaming_count(&self) -> usize;
|
||||
}
|
||||
|
||||
/// High-level application state driving the non-blocking TUI.
|
||||
pub struct App {
|
||||
provider_manager: Arc<ProviderManager>,
|
||||
message_tx: mpsc::UnboundedSender<AppMessage>,
|
||||
message_rx: Option<mpsc::UnboundedReceiver<AppMessage>>,
|
||||
active_generation: Option<ActiveGeneration>,
|
||||
}
|
||||
|
||||
impl App {
|
||||
/// Construct a new application instance with an associated message channel.
|
||||
pub fn new(provider_manager: Arc<ProviderManager>) -> Self {
|
||||
let (message_tx, message_rx) = mpsc::unbounded_channel();
|
||||
|
||||
Self {
|
||||
provider_manager,
|
||||
message_tx,
|
||||
message_rx: Some(message_rx),
|
||||
active_generation: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Cloneable sender handle for pushing messages into the application loop.
|
||||
pub fn message_sender(&self) -> mpsc::UnboundedSender<AppMessage> {
|
||||
self.message_tx.clone()
|
||||
}
|
||||
|
||||
/// Whether a generation task is currently in flight.
|
||||
pub fn has_active_generation(&self) -> bool {
|
||||
self.active_generation.is_some()
|
||||
}
|
||||
|
||||
/// Abort any in-flight generation task.
|
||||
pub fn abort_active_generation(&mut self) {
|
||||
if let Some(active) = self.active_generation.take() {
|
||||
active.abort();
|
||||
}
|
||||
}
|
||||
|
||||
/// Launch the background worker responsible for provider health checks.
|
||||
pub fn spawn_background_worker(&self) -> JoinHandle<()> {
|
||||
let manager = Arc::clone(&self.provider_manager);
|
||||
let sender = self.message_tx.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
worker::background_worker(manager, sender).await;
|
||||
})
|
||||
}
|
||||
|
||||
/// Drive the main UI loop, handling terminal events, background messages, and
|
||||
/// provider status updates without blocking rendering.
|
||||
pub async fn run<State, RenderFn>(
|
||||
&mut self,
|
||||
terminal: &mut Terminal<CrosstermBackend<io::Stdout>>,
|
||||
state: &mut State,
|
||||
session_rx: &mut mpsc::UnboundedReceiver<SessionEvent>,
|
||||
mut render: RenderFn,
|
||||
) -> Result<AppState>
|
||||
where
|
||||
State: UiRuntime,
|
||||
RenderFn: FnMut(&mut Terminal<CrosstermBackend<io::Stdout>>, &mut State) -> Result<()>,
|
||||
{
|
||||
let mut message_rx = self
|
||||
.message_rx
|
||||
.take()
|
||||
.expect("App::run called without an available message receiver");
|
||||
|
||||
let poll_interval = Duration::from_millis(16);
|
||||
let mut last_frame = Instant::now();
|
||||
let frame_interval = Duration::from_millis(16);
|
||||
|
||||
let mut worker_handle = Some(self.spawn_background_worker());
|
||||
|
||||
let exit_state = AppState::Quit;
|
||||
'main: loop {
|
||||
state.advance_loading_animation();
|
||||
|
||||
state.process_pending_llm_request().await?;
|
||||
state.process_pending_tool_execution().await?;
|
||||
state.poll_controller_events()?;
|
||||
|
||||
loop {
|
||||
match session_rx.try_recv() {
|
||||
Ok(session_event) => {
|
||||
state.handle_session_event(session_event).await?;
|
||||
}
|
||||
Err(TryRecvError::Empty) => break,
|
||||
Err(TryRecvError::Disconnected) => {
|
||||
break 'main;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
loop {
|
||||
match message_rx.try_recv() {
|
||||
Ok(message) => {
|
||||
if self.handle_message(state, message) {
|
||||
if let Some(handle) = worker_handle.take() {
|
||||
handle.abort();
|
||||
}
|
||||
break 'main;
|
||||
}
|
||||
}
|
||||
Err(tokio::sync::mpsc::error::TryRecvError::Empty) => break,
|
||||
Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => break,
|
||||
}
|
||||
}
|
||||
|
||||
if last_frame.elapsed() >= frame_interval {
|
||||
render(terminal, state)?;
|
||||
last_frame = Instant::now();
|
||||
}
|
||||
|
||||
if self.handle_message(state, AppMessage::Tick) {
|
||||
if let Some(handle) = worker_handle.take() {
|
||||
handle.abort();
|
||||
}
|
||||
break 'main;
|
||||
}
|
||||
|
||||
match event::poll(poll_interval) {
|
||||
Ok(true) => match event::read() {
|
||||
Ok(raw_event) => {
|
||||
if let Some(ui_event) = events::from_crossterm_event(raw_event) {
|
||||
if let Event::Key(key) = &ui_event {
|
||||
if key.kind == KeyEventKind::Press {
|
||||
let _ = self.message_tx.send(AppMessage::KeyPress(*key));
|
||||
}
|
||||
} else if let Event::Resize(width, height) = &ui_event {
|
||||
let _ = self.message_tx.send(AppMessage::Resize {
|
||||
width: *width,
|
||||
height: *height,
|
||||
});
|
||||
}
|
||||
|
||||
if matches!(state.handle_ui_event(ui_event).await?, AppState::Quit) {
|
||||
if let Some(handle) = worker_handle.take() {
|
||||
handle.abort();
|
||||
}
|
||||
break 'main;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
if let Some(handle) = worker_handle.take() {
|
||||
handle.abort();
|
||||
}
|
||||
return Err(err.into());
|
||||
}
|
||||
},
|
||||
Ok(false) => {}
|
||||
Err(err) => {
|
||||
if let Some(handle) = worker_handle.take() {
|
||||
handle.abort();
|
||||
}
|
||||
return Err(err.into());
|
||||
}
|
||||
}
|
||||
|
||||
yield_now().await;
|
||||
}
|
||||
|
||||
if let Some(handle) = worker_handle {
|
||||
handle.abort();
|
||||
}
|
||||
|
||||
self.message_rx = Some(message_rx);
|
||||
|
||||
Ok(exit_state)
|
||||
}
|
||||
}
|
||||
|
||||
struct ActiveGeneration {
|
||||
request_id: Uuid,
|
||||
#[allow(dead_code)]
|
||||
provider_id: String,
|
||||
abort_handle: AbortHandle,
|
||||
#[allow(dead_code)]
|
||||
join_handle: JoinHandle<()>,
|
||||
}
|
||||
|
||||
impl ActiveGeneration {
|
||||
fn new(request_id: Uuid, provider_id: String, join_handle: JoinHandle<()>) -> Self {
|
||||
let abort_handle = join_handle.abort_handle();
|
||||
Self {
|
||||
request_id,
|
||||
provider_id,
|
||||
abort_handle,
|
||||
join_handle,
|
||||
}
|
||||
}
|
||||
|
||||
fn abort(self) {
|
||||
self.abort_handle.abort();
|
||||
}
|
||||
|
||||
fn request_id(&self) -> Uuid {
|
||||
self.request_id
|
||||
}
|
||||
}
|
||||
165
crates/owlen-tui/src/app/mvu.rs
Normal file
165
crates/owlen-tui/src/app/mvu.rs
Normal file
@@ -0,0 +1,165 @@
|
||||
use owlen_core::{consent::ConsentScope, ui::InputMode};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct AppModel {
|
||||
pub composer: ComposerModel,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ComposerModel {
|
||||
pub draft: String,
|
||||
pub pending_submit: bool,
|
||||
pub mode: InputMode,
|
||||
}
|
||||
|
||||
impl Default for ComposerModel {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
draft: String::new(),
|
||||
pending_submit: false,
|
||||
mode: InputMode::Normal,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AppEvent {
|
||||
Composer(ComposerEvent),
|
||||
ToolPermission {
|
||||
request_id: Uuid,
|
||||
scope: ConsentScope,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ComposerEvent {
|
||||
DraftChanged { content: String },
|
||||
ModeChanged { mode: InputMode },
|
||||
Submit,
|
||||
SubmissionHandled { result: SubmissionOutcome },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum SubmissionOutcome {
|
||||
MessageSent,
|
||||
CommandExecuted,
|
||||
Failed,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AppEffect {
|
||||
SetStatus(String),
|
||||
RequestSubmit,
|
||||
ResolveToolConsent {
|
||||
request_id: Uuid,
|
||||
scope: ConsentScope,
|
||||
},
|
||||
}
|
||||
|
||||
pub fn update(model: &mut AppModel, event: AppEvent) -> Vec<AppEffect> {
|
||||
match event {
|
||||
AppEvent::Composer(event) => update_composer(&mut model.composer, event),
|
||||
AppEvent::ToolPermission { request_id, scope } => {
|
||||
vec![AppEffect::ResolveToolConsent { request_id, scope }]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn update_composer(model: &mut ComposerModel, event: ComposerEvent) -> Vec<AppEffect> {
|
||||
match event {
|
||||
ComposerEvent::DraftChanged { content } => {
|
||||
model.draft = content;
|
||||
Vec::new()
|
||||
}
|
||||
ComposerEvent::ModeChanged { mode } => {
|
||||
model.mode = mode;
|
||||
Vec::new()
|
||||
}
|
||||
ComposerEvent::Submit => {
|
||||
if model.draft.trim().is_empty() {
|
||||
return vec![AppEffect::SetStatus(
|
||||
"Cannot send empty message".to_string(),
|
||||
)];
|
||||
}
|
||||
|
||||
model.pending_submit = true;
|
||||
vec![AppEffect::RequestSubmit]
|
||||
}
|
||||
ComposerEvent::SubmissionHandled { result } => {
|
||||
model.pending_submit = false;
|
||||
match result {
|
||||
SubmissionOutcome::MessageSent | SubmissionOutcome::CommandExecuted => {
|
||||
model.draft.clear();
|
||||
if model.mode == InputMode::Editing {
|
||||
model.mode = InputMode::Normal;
|
||||
}
|
||||
}
|
||||
SubmissionOutcome::Failed => {}
|
||||
}
|
||||
Vec::new()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn submit_with_empty_draft_sets_error() {
|
||||
let mut model = AppModel::default();
|
||||
let effects = update(&mut model, AppEvent::Composer(ComposerEvent::Submit));
|
||||
|
||||
assert!(!model.composer.pending_submit);
|
||||
assert_eq!(effects.len(), 1);
|
||||
match &effects[0] {
|
||||
AppEffect::SetStatus(message) => {
|
||||
assert!(message.contains("Cannot send empty message"));
|
||||
}
|
||||
other => panic!("unexpected effect: {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn submit_with_content_requests_processing() {
|
||||
let mut model = AppModel::default();
|
||||
let _ = update(
|
||||
&mut model,
|
||||
AppEvent::Composer(ComposerEvent::DraftChanged {
|
||||
content: "hello world".into(),
|
||||
}),
|
||||
);
|
||||
|
||||
let effects = update(&mut model, AppEvent::Composer(ComposerEvent::Submit));
|
||||
|
||||
assert!(model.composer.pending_submit);
|
||||
assert_eq!(effects.len(), 1);
|
||||
matches!(effects[0], AppEffect::RequestSubmit);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn submission_success_clears_draft_and_mode() {
|
||||
let mut model = AppModel::default();
|
||||
let _ = update(
|
||||
&mut model,
|
||||
AppEvent::Composer(ComposerEvent::DraftChanged {
|
||||
content: "hello world".into(),
|
||||
}),
|
||||
);
|
||||
let _ = update(&mut model, AppEvent::Composer(ComposerEvent::Submit));
|
||||
assert!(model.composer.pending_submit);
|
||||
|
||||
let effects = update(
|
||||
&mut model,
|
||||
AppEvent::Composer(ComposerEvent::SubmissionHandled {
|
||||
result: SubmissionOutcome::MessageSent,
|
||||
}),
|
||||
);
|
||||
|
||||
assert!(effects.is_empty());
|
||||
assert!(!model.composer.pending_submit);
|
||||
assert!(model.composer.draft.is_empty());
|
||||
assert_eq!(model.composer.mode, InputMode::Normal);
|
||||
}
|
||||
}
|
||||
52
crates/owlen-tui/src/app/worker.rs
Normal file
52
crates/owlen-tui/src/app/worker.rs
Normal file
@@ -0,0 +1,52 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use tokio::{sync::mpsc, time};
|
||||
|
||||
use owlen_core::provider::ProviderManager;
|
||||
|
||||
use super::AppMessage;
|
||||
|
||||
const HEALTH_CHECK_INTERVAL: Duration = Duration::from_secs(30);
|
||||
|
||||
/// Periodically refresh provider health and emit status updates into the app's
|
||||
/// message channel. Exits automatically once the receiver side of the channel
|
||||
/// is dropped.
|
||||
pub async fn background_worker(
|
||||
provider_manager: Arc<ProviderManager>,
|
||||
message_tx: mpsc::UnboundedSender<AppMessage>,
|
||||
) {
|
||||
let mut interval = time::interval(HEALTH_CHECK_INTERVAL);
|
||||
let mut last_statuses = provider_manager.provider_statuses().await;
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
if message_tx.is_closed() {
|
||||
break;
|
||||
}
|
||||
|
||||
let statuses = provider_manager.refresh_health().await;
|
||||
|
||||
for (provider_id, status) in statuses {
|
||||
let changed = match last_statuses.get(&provider_id) {
|
||||
Some(previous) => previous != &status,
|
||||
None => true,
|
||||
};
|
||||
|
||||
last_statuses.insert(provider_id.clone(), status);
|
||||
|
||||
if changed
|
||||
&& message_tx
|
||||
.send(AppMessage::ProviderStatus {
|
||||
provider_id,
|
||||
status,
|
||||
})
|
||||
.is_err()
|
||||
{
|
||||
// Receiver dropped; terminate worker.
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,5 +1,5 @@
|
||||
use anyhow::Result;
|
||||
use owlen_core::session::SessionController;
|
||||
use owlen_core::session::{ControllerEvent, SessionController};
|
||||
use owlen_core::ui::{AppState, InputMode};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
@@ -16,11 +16,12 @@ pub struct CodeApp {
|
||||
impl CodeApp {
|
||||
pub async fn new(
|
||||
mut controller: SessionController,
|
||||
controller_event_rx: mpsc::UnboundedReceiver<ControllerEvent>,
|
||||
) -> Result<(Self, mpsc::UnboundedReceiver<SessionEvent>)> {
|
||||
controller
|
||||
.conversation_mut()
|
||||
.push_system_message(DEFAULT_SYSTEM_PROMPT.to_string());
|
||||
let (inner, rx) = ChatApp::new(controller).await?;
|
||||
let (inner, rx) = ChatApp::new(controller, controller_event_rx).await?;
|
||||
Ok((Self { inner }, rx))
|
||||
}
|
||||
|
||||
@@ -28,8 +29,8 @@ impl CodeApp {
|
||||
self.inner.handle_event(event).await
|
||||
}
|
||||
|
||||
pub fn handle_session_event(&mut self, event: SessionEvent) -> Result<()> {
|
||||
self.inner.handle_session_event(event)
|
||||
pub async fn handle_session_event(&mut self, event: SessionEvent) -> Result<()> {
|
||||
self.inner.handle_session_event(event).await
|
||||
}
|
||||
|
||||
pub fn mode(&self) -> InputMode {
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
//! Command catalog and lookup utilities for the command palette.
|
||||
pub mod registry;
|
||||
pub use registry::{AppCommand, CommandRegistry};
|
||||
|
||||
// Command catalog and lookup utilities for the command palette.
|
||||
|
||||
/// Metadata describing a single command keyword.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
@@ -14,7 +17,15 @@ const COMMANDS: &[CommandSpec] = &[
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "q",
|
||||
description: "Alias for quit",
|
||||
description: "Close the active file",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "w",
|
||||
description: "Save the active file",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "write",
|
||||
description: "Alias for w",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "clear",
|
||||
@@ -25,12 +36,16 @@ const COMMANDS: &[CommandSpec] = &[
|
||||
description: "Alias for clear",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "w",
|
||||
description: "Alias for write",
|
||||
keyword: "save",
|
||||
description: "Alias for w",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "save",
|
||||
description: "Alias for write",
|
||||
keyword: "wq",
|
||||
description: "Save and close the active file",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "x",
|
||||
description: "Alias for wq",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "load",
|
||||
@@ -44,6 +59,10 @@ const COMMANDS: &[CommandSpec] = &[
|
||||
keyword: "open",
|
||||
description: "Open a file in the code view",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "create",
|
||||
description: "Create a file (creates missing directories)",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "close",
|
||||
description: "Close the active code view",
|
||||
@@ -68,9 +87,13 @@ const COMMANDS: &[CommandSpec] = &[
|
||||
keyword: "sessions",
|
||||
description: "List saved sessions",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "session save",
|
||||
description: "Save the current conversation",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "help",
|
||||
description: "Show help documentation",
|
||||
description: "Open the help overlay",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "h",
|
||||
@@ -82,7 +105,23 @@ const COMMANDS: &[CommandSpec] = &[
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "provider",
|
||||
description: "Switch active provider",
|
||||
description: "Switch provider or set its mode",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "cloud setup",
|
||||
description: "Configure Ollama Cloud credentials",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "cloud status",
|
||||
description: "Check Ollama Cloud connectivity",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "cloud models",
|
||||
description: "List models available in Ollama Cloud",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "cloud logout",
|
||||
description: "Remove stored Ollama Cloud credentials",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "model info",
|
||||
@@ -104,6 +143,18 @@ const COMMANDS: &[CommandSpec] = &[
|
||||
keyword: "models info",
|
||||
description: "Prefetch detailed information for all models",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "models --local",
|
||||
description: "Open model picker focused on local models",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "models --cloud",
|
||||
description: "Open model picker focused on cloud models",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "models --available",
|
||||
description: "Open model picker showing available models",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "new",
|
||||
description: "Start a new conversation",
|
||||
@@ -128,6 +179,10 @@ const COMMANDS: &[CommandSpec] = &[
|
||||
keyword: "reload",
|
||||
description: "Reload configuration and themes",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "markdown",
|
||||
description: "Toggle markdown rendering",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "e",
|
||||
description: "Edit a file",
|
||||
@@ -160,6 +215,38 @@ const COMMANDS: &[CommandSpec] = &[
|
||||
keyword: "stop-agent",
|
||||
description: "Stop the running agent",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "agent status",
|
||||
description: "Show current agent status",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "agent start",
|
||||
description: "Arm the agent for the next request",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "agent stop",
|
||||
description: "Stop the running agent",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "layout save",
|
||||
description: "Persist the current pane layout",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "layout load",
|
||||
description: "Restore the last saved pane layout",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "files",
|
||||
description: "Toggle the files panel",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "explorer",
|
||||
description: "Alias for files",
|
||||
},
|
||||
CommandSpec {
|
||||
keyword: "debug log",
|
||||
description: "Toggle the debug log panel",
|
||||
},
|
||||
];
|
||||
|
||||
/// Return the static catalog of commands.
|
||||
@@ -168,29 +255,35 @@ pub fn all() -> &'static [CommandSpec] {
|
||||
}
|
||||
|
||||
/// Return the default suggestion list (all command keywords).
|
||||
pub fn default_suggestions() -> Vec<String> {
|
||||
COMMANDS
|
||||
.iter()
|
||||
.map(|spec| spec.keyword.to_string())
|
||||
.collect()
|
||||
pub fn default_suggestions() -> Vec<CommandSpec> {
|
||||
COMMANDS.to_vec()
|
||||
}
|
||||
|
||||
/// Generate keyword suggestions for the given input.
|
||||
pub fn suggestions(input: &str) -> Vec<String> {
|
||||
pub fn suggestions(input: &str) -> Vec<CommandSpec> {
|
||||
let trimmed = input.trim();
|
||||
if trimmed.is_empty() {
|
||||
return default_suggestions();
|
||||
}
|
||||
COMMANDS
|
||||
|
||||
let mut matches: Vec<(usize, usize, CommandSpec)> = COMMANDS
|
||||
.iter()
|
||||
.filter_map(|spec| {
|
||||
if spec.keyword.starts_with(trimmed) {
|
||||
Some(spec.keyword.to_string())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
match_score(spec.keyword, trimmed).map(|score| (score.0, score.1, *spec))
|
||||
})
|
||||
.collect()
|
||||
.collect();
|
||||
|
||||
if matches.is_empty() {
|
||||
return default_suggestions();
|
||||
}
|
||||
|
||||
matches.sort_by(|a, b| {
|
||||
a.0.cmp(&b.0)
|
||||
.then(a.1.cmp(&b.1))
|
||||
.then(a.2.keyword.cmp(b.2.keyword))
|
||||
});
|
||||
|
||||
matches.into_iter().map(|(_, _, spec)| spec).collect()
|
||||
}
|
||||
|
||||
pub fn match_score(candidate: &str, query: &str) -> Option<(usize, usize)> {
|
||||
@@ -209,7 +302,7 @@ pub fn match_score(candidate: &str, query: &str) -> Option<(usize, usize)> {
|
||||
if candidate_normalized == query_normalized {
|
||||
Some((0, candidate.len()))
|
||||
} else if candidate_normalized.starts_with(&query_normalized) {
|
||||
Some((1, candidate.len()))
|
||||
Some((1, 0))
|
||||
} else if let Some(pos) = candidate_normalized.find(&query_normalized) {
|
||||
Some((2, pos))
|
||||
} else if is_subsequence(&candidate_normalized, &query_normalized) {
|
||||
@@ -219,6 +312,19 @@ pub fn match_score(candidate: &str, query: &str) -> Option<(usize, usize)> {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn suggestions_prioritize_agent_start() {
|
||||
let results = suggestions("agent st");
|
||||
assert!(!results.is_empty());
|
||||
assert_eq!(results[0].keyword, "agent start");
|
||||
assert!(results.iter().any(|spec| spec.keyword == "agent stop"));
|
||||
}
|
||||
}
|
||||
|
||||
fn is_subsequence(text: &str, pattern: &str) -> bool {
|
||||
if pattern.is_empty() {
|
||||
return true;
|
||||
|
||||
107
crates/owlen-tui/src/commands/registry.rs
Normal file
107
crates/owlen-tui/src/commands/registry.rs
Normal file
@@ -0,0 +1,107 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use owlen_core::ui::FocusedPanel;
|
||||
|
||||
use crate::widgets::model_picker::FilterMode;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum AppCommand {
|
||||
OpenModelPicker(Option<FilterMode>),
|
||||
OpenCommandPalette,
|
||||
CycleFocusForward,
|
||||
CycleFocusBackward,
|
||||
FocusPanel(FocusedPanel),
|
||||
ComposerSubmit,
|
||||
EnterCommandMode,
|
||||
ToggleDebugLog,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct CommandRegistry {
|
||||
commands: HashMap<String, AppCommand>,
|
||||
}
|
||||
|
||||
impl CommandRegistry {
|
||||
pub fn new() -> Self {
|
||||
let mut commands = HashMap::new();
|
||||
|
||||
commands.insert(
|
||||
"model.open_all".to_string(),
|
||||
AppCommand::OpenModelPicker(None),
|
||||
);
|
||||
commands.insert(
|
||||
"model.open_local".to_string(),
|
||||
AppCommand::OpenModelPicker(Some(FilterMode::LocalOnly)),
|
||||
);
|
||||
commands.insert(
|
||||
"model.open_cloud".to_string(),
|
||||
AppCommand::OpenModelPicker(Some(FilterMode::CloudOnly)),
|
||||
);
|
||||
commands.insert(
|
||||
"model.open_available".to_string(),
|
||||
AppCommand::OpenModelPicker(Some(FilterMode::Available)),
|
||||
);
|
||||
commands.insert("palette.open".to_string(), AppCommand::OpenCommandPalette);
|
||||
commands.insert("focus.next".to_string(), AppCommand::CycleFocusForward);
|
||||
commands.insert("focus.prev".to_string(), AppCommand::CycleFocusBackward);
|
||||
commands.insert(
|
||||
"focus.files".to_string(),
|
||||
AppCommand::FocusPanel(FocusedPanel::Files),
|
||||
);
|
||||
commands.insert(
|
||||
"focus.chat".to_string(),
|
||||
AppCommand::FocusPanel(FocusedPanel::Chat),
|
||||
);
|
||||
commands.insert(
|
||||
"focus.thinking".to_string(),
|
||||
AppCommand::FocusPanel(FocusedPanel::Thinking),
|
||||
);
|
||||
commands.insert(
|
||||
"focus.input".to_string(),
|
||||
AppCommand::FocusPanel(FocusedPanel::Input),
|
||||
);
|
||||
commands.insert(
|
||||
"focus.code".to_string(),
|
||||
AppCommand::FocusPanel(FocusedPanel::Code),
|
||||
);
|
||||
commands.insert("composer.submit".to_string(), AppCommand::ComposerSubmit);
|
||||
commands.insert("mode.command".to_string(), AppCommand::EnterCommandMode);
|
||||
commands.insert("debug.toggle".to_string(), AppCommand::ToggleDebugLog);
|
||||
|
||||
Self { commands }
|
||||
}
|
||||
|
||||
pub fn resolve(&self, command: &str) -> Option<AppCommand> {
|
||||
self.commands.get(command).copied()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CommandRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn resolve_known_command() {
|
||||
let registry = CommandRegistry::new();
|
||||
assert_eq!(
|
||||
registry.resolve("focus.next"),
|
||||
Some(AppCommand::CycleFocusForward)
|
||||
);
|
||||
assert_eq!(
|
||||
registry.resolve("model.open_cloud"),
|
||||
Some(AppCommand::OpenModelPicker(Some(FilterMode::CloudOnly)))
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_unknown_command() {
|
||||
let registry = CommandRegistry::new();
|
||||
assert_eq!(registry.resolve("does.not.exist"), None);
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
pub use owlen_core::config::{
|
||||
Config, DEFAULT_CONFIG_PATH, GeneralSettings, InputSettings, StorageSettings, UiSettings,
|
||||
default_config_path, ensure_ollama_config, ensure_provider_config, session_timeout,
|
||||
Config, DEFAULT_CONFIG_PATH, GeneralSettings, IconMode, InputSettings, StorageSettings,
|
||||
UiSettings, default_config_path, ensure_ollama_config, ensure_provider_config, session_timeout,
|
||||
};
|
||||
|
||||
/// Attempt to load configuration from default location
|
||||
|
||||
@@ -17,6 +17,22 @@ pub enum Event {
|
||||
Tick,
|
||||
}
|
||||
|
||||
/// Convert a raw crossterm event into an application event.
|
||||
pub fn from_crossterm_event(raw: crossterm::event::Event) -> Option<Event> {
|
||||
match raw {
|
||||
crossterm::event::Event::Key(key) => {
|
||||
if key.kind == KeyEventKind::Press {
|
||||
Some(Event::Key(key))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
crossterm::event::Event::Resize(width, height) => Some(Event::Resize(width, height)),
|
||||
crossterm::event::Event::Paste(text) => Some(Event::Paste(text)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Event handler that captures terminal events and sends them to the application
|
||||
pub struct EventHandler {
|
||||
sender: mpsc::UnboundedSender<Event>,
|
||||
@@ -52,20 +68,8 @@ impl EventHandler {
|
||||
if event::poll(timeout).unwrap_or(false) {
|
||||
match event::read() {
|
||||
Ok(event) => {
|
||||
match event {
|
||||
crossterm::event::Event::Key(key) => {
|
||||
// Only handle KeyEventKind::Press to avoid duplicate events
|
||||
if key.kind == KeyEventKind::Press {
|
||||
let _ = self.sender.send(Event::Key(key));
|
||||
}
|
||||
}
|
||||
crossterm::event::Event::Resize(width, height) => {
|
||||
let _ = self.sender.send(Event::Resize(width, height));
|
||||
}
|
||||
crossterm::event::Event::Paste(text) => {
|
||||
let _ = self.sender.send(Event::Paste(text));
|
||||
}
|
||||
_ => {}
|
||||
if let Some(converted) = from_crossterm_event(event) {
|
||||
let _ = self.sender.send(converted);
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
|
||||
160
crates/owlen-tui/src/highlight.rs
Normal file
160
crates/owlen-tui/src/highlight.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
use once_cell::sync::Lazy;
|
||||
use ratatui::style::{Color as TuiColor, Modifier, Style as TuiStyle};
|
||||
use std::path::{Path, PathBuf};
|
||||
use syntect::easy::HighlightLines;
|
||||
use syntect::highlighting::{FontStyle, Style as SynStyle, Theme, ThemeSet};
|
||||
use syntect::parsing::{SyntaxReference, SyntaxSet};
|
||||
|
||||
static SYNTAX_SET: Lazy<SyntaxSet> = Lazy::new(SyntaxSet::load_defaults_newlines);
|
||||
static THEME_SET: Lazy<ThemeSet> = Lazy::new(ThemeSet::load_defaults);
|
||||
static THEME: Lazy<Theme> = Lazy::new(|| {
|
||||
THEME_SET
|
||||
.themes
|
||||
.get("base16-ocean.dark")
|
||||
.cloned()
|
||||
.or_else(|| THEME_SET.themes.values().next().cloned())
|
||||
.unwrap_or_default()
|
||||
});
|
||||
|
||||
fn select_syntax(path_hint: Option<&Path>) -> &'static SyntaxReference {
|
||||
if let Some(path) = path_hint {
|
||||
if let Ok(Some(syntax)) = SYNTAX_SET.find_syntax_for_file(path) {
|
||||
return syntax;
|
||||
}
|
||||
if let Some(ext) = path.extension().and_then(|ext| ext.to_str()) {
|
||||
if let Some(syntax) = SYNTAX_SET.find_syntax_by_extension(ext) {
|
||||
return syntax;
|
||||
}
|
||||
}
|
||||
if let Some(name) = path.file_name().and_then(|name| name.to_str()) {
|
||||
if let Some(syntax) = SYNTAX_SET.find_syntax_by_token(name) {
|
||||
return syntax;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SYNTAX_SET.find_syntax_plain_text()
|
||||
}
|
||||
|
||||
fn select_syntax_for_language(language: Option<&str>) -> &'static SyntaxReference {
|
||||
let token = language
|
||||
.map(|lang| lang.trim().to_ascii_lowercase())
|
||||
.filter(|lang| !lang.is_empty());
|
||||
|
||||
if let Some(token) = token {
|
||||
let mut attempts: Vec<&str> = vec![token.as_str()];
|
||||
match token.as_str() {
|
||||
"c++" => attempts.extend(["cpp", "c"]),
|
||||
"c#" | "cs" => attempts.extend(["csharp", "cs"]),
|
||||
"shell" => attempts.extend(["bash", "sh"]),
|
||||
"typescript" | "ts" => attempts.extend(["typescript", "ts", "tsx"]),
|
||||
"javascript" | "js" => attempts.extend(["javascript", "js", "jsx"]),
|
||||
"py" => attempts.push("python"),
|
||||
"rs" => attempts.push("rust"),
|
||||
"yml" => attempts.push("yaml"),
|
||||
other => {
|
||||
if let Some(stripped) = other.strip_prefix('.') {
|
||||
attempts.push(stripped);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for candidate in attempts {
|
||||
if let Some(syntax) = SYNTAX_SET.find_syntax_by_token(candidate) {
|
||||
return syntax;
|
||||
}
|
||||
if let Some(syntax) = SYNTAX_SET.find_syntax_by_extension(candidate) {
|
||||
return syntax;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
SYNTAX_SET.find_syntax_plain_text()
|
||||
}
|
||||
|
||||
fn path_hint_from_components(absolute: Option<&Path>, display: Option<&str>) -> Option<PathBuf> {
|
||||
if let Some(abs) = absolute {
|
||||
return Some(abs.to_path_buf());
|
||||
}
|
||||
display.map(PathBuf::from)
|
||||
}
|
||||
|
||||
fn style_from_syntect(style: SynStyle) -> TuiStyle {
|
||||
let mut tui_style = TuiStyle::default().fg(TuiColor::Rgb(
|
||||
style.foreground.r,
|
||||
style.foreground.g,
|
||||
style.foreground.b,
|
||||
));
|
||||
|
||||
let mut modifiers = Modifier::empty();
|
||||
if style.font_style.contains(FontStyle::BOLD) {
|
||||
modifiers |= Modifier::BOLD;
|
||||
}
|
||||
if style.font_style.contains(FontStyle::ITALIC) {
|
||||
modifiers |= Modifier::ITALIC;
|
||||
}
|
||||
if style.font_style.contains(FontStyle::UNDERLINE) {
|
||||
modifiers |= Modifier::UNDERLINED;
|
||||
}
|
||||
|
||||
if !modifiers.is_empty() {
|
||||
tui_style = tui_style.add_modifier(modifiers);
|
||||
}
|
||||
|
||||
tui_style
|
||||
}
|
||||
|
||||
pub fn build_highlighter(
|
||||
absolute: Option<&Path>,
|
||||
display: Option<&str>,
|
||||
) -> HighlightLines<'static> {
|
||||
let hint_path = path_hint_from_components(absolute, display);
|
||||
let syntax = select_syntax(hint_path.as_deref());
|
||||
HighlightLines::new(syntax, &THEME)
|
||||
}
|
||||
|
||||
pub fn highlight_line(
|
||||
highlighter: &mut HighlightLines<'static>,
|
||||
line: &str,
|
||||
) -> Vec<(TuiStyle, String)> {
|
||||
let mut segments = Vec::new();
|
||||
match highlighter.highlight_line(line, &SYNTAX_SET) {
|
||||
Ok(result) => {
|
||||
for (style, piece) in result {
|
||||
let tui_style = style_from_syntect(style);
|
||||
segments.push((tui_style, piece.to_string()));
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
segments.push((TuiStyle::default(), line.to_string()));
|
||||
}
|
||||
}
|
||||
|
||||
if segments.is_empty() {
|
||||
segments.push((TuiStyle::default(), String::new()));
|
||||
}
|
||||
|
||||
segments
|
||||
}
|
||||
|
||||
pub fn build_highlighter_for_language(language: Option<&str>) -> HighlightLines<'static> {
|
||||
let syntax = select_syntax_for_language(language);
|
||||
HighlightLines::new(syntax, &THEME)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn rust_highlighting_produces_colored_segment() {
|
||||
let mut highlighter = build_highlighter_for_language(Some("rust"));
|
||||
let segments = highlight_line(&mut highlighter, "fn main() {}");
|
||||
assert!(
|
||||
segments
|
||||
.iter()
|
||||
.any(|(style, text)| style.fg.is_some() && !text.trim().is_empty()),
|
||||
"Expected at least one colored segment"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,5 @@
|
||||
#![allow(clippy::collapsible_if)] // TODO: Remove once Rust 2024 let-chains are available
|
||||
|
||||
//! # Owlen TUI
|
||||
//!
|
||||
//! This crate contains all the logic for the terminal user interface (TUI) of Owlen.
|
||||
@@ -12,15 +14,20 @@
|
||||
//! - `events`: Event handling for user input and other asynchronous actions.
|
||||
//! - `ui`: The rendering logic for all TUI components.
|
||||
|
||||
pub mod app;
|
||||
pub mod chat_app;
|
||||
pub mod code_app;
|
||||
pub mod commands;
|
||||
pub mod config;
|
||||
pub mod events;
|
||||
pub mod highlight;
|
||||
pub mod model_info_panel;
|
||||
pub mod slash;
|
||||
pub mod state;
|
||||
pub mod toast;
|
||||
pub mod tui_controller;
|
||||
pub mod ui;
|
||||
pub mod widgets;
|
||||
|
||||
pub use chat_app::{ChatApp, SessionEvent};
|
||||
pub use code_app::CodeApp;
|
||||
|
||||
238
crates/owlen-tui/src/slash.rs
Normal file
238
crates/owlen-tui/src/slash.rs
Normal file
@@ -0,0 +1,238 @@
|
||||
//! Slash command parsing for chat input.
|
||||
//!
|
||||
//! Provides lightweight handling for inline commands such as `/summarize`
|
||||
//! and `/testplan`. The parser returns owned data so callers can prepare
|
||||
//! requests immediately without additional lifetime juggling.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::str::FromStr;
|
||||
use std::sync::{OnceLock, RwLock};
|
||||
|
||||
/// Supported slash commands.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum SlashCommand {
|
||||
Summarize { count: Option<usize> },
|
||||
Explain { snippet: String },
|
||||
Refactor { path: String },
|
||||
TestPlan,
|
||||
Compact,
|
||||
McpTool { server: String, tool: String },
|
||||
}
|
||||
|
||||
/// Errors emitted when parsing invalid slash input.
|
||||
#[derive(Debug)]
|
||||
pub enum SlashError {
|
||||
UnknownCommand(String),
|
||||
Message(String),
|
||||
}
|
||||
|
||||
impl fmt::Display for SlashError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
SlashError::UnknownCommand(cmd) => write!(f, "unknown slash command: {cmd}"),
|
||||
SlashError::Message(msg) => f.write_str(msg),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for SlashError {}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct McpSlashCommand {
|
||||
pub server: String,
|
||||
pub tool: String,
|
||||
pub keyword: String,
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
impl McpSlashCommand {
|
||||
pub fn new(
|
||||
server: impl Into<String>,
|
||||
tool: impl Into<String>,
|
||||
description: Option<String>,
|
||||
) -> Self {
|
||||
let server = server.into();
|
||||
let tool = tool.into();
|
||||
let keyword = format!(
|
||||
"mcp__{}__{}",
|
||||
canonicalize_component(&server),
|
||||
canonicalize_component(&tool)
|
||||
);
|
||||
Self {
|
||||
server,
|
||||
tool,
|
||||
keyword,
|
||||
description,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static MCP_COMMANDS: OnceLock<RwLock<HashMap<String, McpSlashCommand>>> = OnceLock::new();
|
||||
|
||||
fn dynamic_registry() -> &'static RwLock<HashMap<String, McpSlashCommand>> {
|
||||
MCP_COMMANDS.get_or_init(|| RwLock::new(HashMap::new()))
|
||||
}
|
||||
|
||||
pub fn set_mcp_commands(commands: impl IntoIterator<Item = McpSlashCommand>) {
|
||||
let registry = dynamic_registry();
|
||||
let mut guard = registry.write().expect("MCP command registry poisoned");
|
||||
guard.clear();
|
||||
for command in commands {
|
||||
guard.insert(command.keyword.clone(), command);
|
||||
}
|
||||
}
|
||||
|
||||
fn find_mcp_command(keyword: &str) -> Option<McpSlashCommand> {
|
||||
let registry = dynamic_registry();
|
||||
let guard = registry.read().expect("MCP command registry poisoned");
|
||||
guard.get(keyword).cloned()
|
||||
}
|
||||
|
||||
fn canonicalize_component(input: &str) -> String {
|
||||
let mut out = String::new();
|
||||
let mut last_was_underscore = false;
|
||||
for ch in input.chars() {
|
||||
let mapped = if ch.is_ascii_alphanumeric() {
|
||||
ch.to_ascii_lowercase()
|
||||
} else {
|
||||
'_'
|
||||
};
|
||||
if mapped == '_' {
|
||||
if !last_was_underscore {
|
||||
out.push('_');
|
||||
last_was_underscore = true;
|
||||
}
|
||||
} else {
|
||||
out.push(mapped);
|
||||
last_was_underscore = false;
|
||||
}
|
||||
}
|
||||
if out.is_empty() { "_".to_string() } else { out }
|
||||
}
|
||||
|
||||
/// Attempt to parse a slash command from the provided input.
|
||||
pub fn parse(input: &str) -> Result<Option<SlashCommand>, SlashError> {
|
||||
let trimmed = input.trim();
|
||||
if !trimmed.starts_with('/') {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let body = trimmed.trim_start_matches('/');
|
||||
if body.is_empty() {
|
||||
return Err(SlashError::Message("missing command name after '/'".into()));
|
||||
}
|
||||
|
||||
let mut parts = body.split_whitespace();
|
||||
let command = parts.next().unwrap();
|
||||
let remainder = parts.collect::<Vec<_>>();
|
||||
|
||||
if let Some(dynamic) = find_mcp_command(command) {
|
||||
if !remainder.is_empty() {
|
||||
return Err(SlashError::Message(format!(
|
||||
"/{} does not accept arguments",
|
||||
dynamic.keyword
|
||||
)));
|
||||
}
|
||||
return Ok(Some(SlashCommand::McpTool {
|
||||
server: dynamic.server,
|
||||
tool: dynamic.tool,
|
||||
}));
|
||||
}
|
||||
|
||||
let cmd = match command {
|
||||
"summarize" => {
|
||||
let count = remainder
|
||||
.first()
|
||||
.and_then(|value| usize::from_str(value).ok());
|
||||
SlashCommand::Summarize { count }
|
||||
}
|
||||
"explain" => {
|
||||
if remainder.is_empty() {
|
||||
return Err(SlashError::Message(
|
||||
"usage: /explain <code snippet or description>".into(),
|
||||
));
|
||||
}
|
||||
SlashCommand::Explain {
|
||||
snippet: remainder.join(" "),
|
||||
}
|
||||
}
|
||||
"refactor" => {
|
||||
if remainder.is_empty() {
|
||||
return Err(SlashError::Message(
|
||||
"usage: /refactor <relative/path/to/file>".into(),
|
||||
));
|
||||
}
|
||||
SlashCommand::Refactor {
|
||||
path: remainder.join(" "),
|
||||
}
|
||||
}
|
||||
"testplan" => SlashCommand::TestPlan,
|
||||
"compact" => SlashCommand::Compact,
|
||||
other => return Err(SlashError::UnknownCommand(other.to_string())),
|
||||
};
|
||||
|
||||
Ok(Some(cmd))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn ignores_non_command_input() {
|
||||
let result = parse("hello world").unwrap();
|
||||
assert!(result.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_summarize_with_count() {
|
||||
let command = parse("/summarize 10").unwrap().expect("expected command");
|
||||
match command {
|
||||
SlashCommand::Summarize { count } => assert_eq!(count, Some(10)),
|
||||
other => panic!("unexpected command: {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn returns_error_for_unknown_command() {
|
||||
let err = parse("/unknown").unwrap_err();
|
||||
assert_eq!(err.to_string(), "unknown slash command: unknown");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parses_registered_mcp_command() {
|
||||
set_mcp_commands(Vec::new());
|
||||
set_mcp_commands(vec![McpSlashCommand::new("github", "list_prs", None)]);
|
||||
|
||||
let command = parse("/mcp__github__list_prs")
|
||||
.unwrap()
|
||||
.expect("expected command");
|
||||
match command {
|
||||
SlashCommand::McpTool { server, tool } => {
|
||||
assert_eq!(server, "github");
|
||||
assert_eq!(tool, "list_prs");
|
||||
}
|
||||
other => panic!("unexpected command variant: {:?}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rejects_mcp_command_with_arguments() {
|
||||
set_mcp_commands(Vec::new());
|
||||
set_mcp_commands(vec![McpSlashCommand::new("github", "list_prs", None)]);
|
||||
|
||||
let err = parse("/mcp__github__list_prs extra").unwrap_err();
|
||||
assert_eq!(
|
||||
err.to_string(),
|
||||
"/mcp__github__list_prs does not accept arguments"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn canonicalizes_mcp_command_components() {
|
||||
set_mcp_commands(Vec::new());
|
||||
let entry = McpSlashCommand::new("GitHub", "list/prs", None);
|
||||
assert_eq!(entry.keyword, "mcp__github__list_prs");
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,32 @@
|
||||
use crate::commands;
|
||||
use crate::commands::{self, CommandSpec};
|
||||
use std::collections::{HashSet, VecDeque};
|
||||
|
||||
const MAX_RESULTS: usize = 12;
|
||||
const MAX_HISTORY_RESULTS: usize = 4;
|
||||
const HISTORY_CAPACITY: usize = 20;
|
||||
|
||||
/// Encapsulates the command-line style palette used in command mode.
|
||||
///
|
||||
/// The palette keeps track of the raw buffer, matching suggestions, and the
|
||||
/// currently highlighted suggestion index. It contains no terminal-specific
|
||||
/// logic which makes it straightforward to unit test.
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum PaletteGroup {
|
||||
History,
|
||||
Command,
|
||||
Model,
|
||||
Provider,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PaletteSuggestion {
|
||||
pub value: String,
|
||||
pub label: String,
|
||||
pub detail: Option<String>,
|
||||
pub group: PaletteGroup,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ModelPaletteEntry {
|
||||
pub id: String,
|
||||
@@ -25,10 +47,11 @@ impl ModelPaletteEntry {
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct CommandPalette {
|
||||
buffer: String,
|
||||
suggestions: Vec<String>,
|
||||
suggestions: Vec<PaletteSuggestion>,
|
||||
selected: usize,
|
||||
models: Vec<ModelPaletteEntry>,
|
||||
providers: Vec<String>,
|
||||
history: VecDeque<String>,
|
||||
}
|
||||
|
||||
impl CommandPalette {
|
||||
@@ -40,7 +63,7 @@ impl CommandPalette {
|
||||
&self.buffer
|
||||
}
|
||||
|
||||
pub fn suggestions(&self) -> &[String] {
|
||||
pub fn suggestions(&self) -> &[PaletteSuggestion] {
|
||||
&self.suggestions
|
||||
}
|
||||
|
||||
@@ -54,6 +77,28 @@ impl CommandPalette {
|
||||
self.selected = 0;
|
||||
}
|
||||
|
||||
pub fn remember(&mut self, value: impl AsRef<str>) {
|
||||
let trimmed = value.as_ref().trim();
|
||||
if trimmed.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Avoid duplicate consecutive entries by removing any existing matching value.
|
||||
if let Some(pos) = self
|
||||
.history
|
||||
.iter()
|
||||
.position(|entry| entry.eq_ignore_ascii_case(trimmed))
|
||||
{
|
||||
self.history.remove(pos);
|
||||
}
|
||||
|
||||
self.history.push_back(trimmed.to_string());
|
||||
|
||||
while self.history.len() > HISTORY_CAPACITY {
|
||||
self.history.pop_front();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_buffer(&mut self, value: impl Into<String>) {
|
||||
self.buffer = value.into();
|
||||
self.refresh_suggestions();
|
||||
@@ -98,11 +143,11 @@ impl CommandPalette {
|
||||
.get(self.selected)
|
||||
.cloned()
|
||||
.or_else(|| self.suggestions.first().cloned());
|
||||
if let Some(value) = selected.clone() {
|
||||
self.buffer = value;
|
||||
if let Some(entry) = selected.clone() {
|
||||
self.buffer = entry.value.clone();
|
||||
self.refresh_suggestions();
|
||||
}
|
||||
selected
|
||||
selected.map(|entry| entry.value)
|
||||
}
|
||||
|
||||
pub fn refresh_suggestions(&mut self) {
|
||||
@@ -119,40 +164,177 @@ impl CommandPalette {
|
||||
}
|
||||
}
|
||||
|
||||
fn dynamic_suggestions(&self, trimmed: &str) -> Vec<String> {
|
||||
if let Some(rest) = trimmed.strip_prefix("model ") {
|
||||
let suggestions = self.model_suggestions("model", rest.trim());
|
||||
if suggestions.is_empty() {
|
||||
commands::suggestions(trimmed)
|
||||
} else {
|
||||
suggestions
|
||||
fn dynamic_suggestions(&self, trimmed: &str) -> Vec<PaletteSuggestion> {
|
||||
let lowered = trimmed.to_ascii_lowercase();
|
||||
let mut results: Vec<PaletteSuggestion> = Vec::new();
|
||||
let mut seen: HashSet<String> = HashSet::new();
|
||||
|
||||
fn push_entries(
|
||||
results: &mut Vec<PaletteSuggestion>,
|
||||
seen: &mut HashSet<String>,
|
||||
entries: Vec<PaletteSuggestion>,
|
||||
) {
|
||||
for entry in entries {
|
||||
if seen.insert(entry.value.to_ascii_lowercase()) {
|
||||
results.push(entry);
|
||||
}
|
||||
if results.len() >= MAX_RESULTS {
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else if let Some(rest) = trimmed.strip_prefix("m ") {
|
||||
let suggestions = self.model_suggestions("m", rest.trim());
|
||||
if suggestions.is_empty() {
|
||||
commands::suggestions(trimmed)
|
||||
} else {
|
||||
suggestions
|
||||
}
|
||||
} else if let Some(rest) = trimmed.strip_prefix("provider ") {
|
||||
let suggestions = self.provider_suggestions("provider", rest.trim());
|
||||
if suggestions.is_empty() {
|
||||
commands::suggestions(trimmed)
|
||||
} else {
|
||||
suggestions
|
||||
}
|
||||
} else {
|
||||
commands::suggestions(trimmed)
|
||||
}
|
||||
|
||||
let history = self.history_suggestions(trimmed);
|
||||
push_entries(&mut results, &mut seen, history);
|
||||
if results.len() >= MAX_RESULTS {
|
||||
return results;
|
||||
}
|
||||
|
||||
if lowered.starts_with("model ") {
|
||||
let rest = trimmed[5..].trim();
|
||||
push_entries(
|
||||
&mut results,
|
||||
&mut seen,
|
||||
self.model_suggestions("model", rest),
|
||||
);
|
||||
if results.len() < MAX_RESULTS {
|
||||
push_entries(&mut results, &mut seen, self.command_entries(trimmed));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
if lowered.starts_with("m ") {
|
||||
let rest = trimmed[2..].trim();
|
||||
push_entries(&mut results, &mut seen, self.model_suggestions("m", rest));
|
||||
if results.len() < MAX_RESULTS {
|
||||
push_entries(&mut results, &mut seen, self.command_entries(trimmed));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
if lowered == "model" {
|
||||
push_entries(&mut results, &mut seen, self.model_suggestions("model", ""));
|
||||
if results.len() < MAX_RESULTS {
|
||||
push_entries(&mut results, &mut seen, self.command_entries(trimmed));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
if lowered.starts_with("provider ") {
|
||||
let rest = trimmed[9..].trim();
|
||||
push_entries(
|
||||
&mut results,
|
||||
&mut seen,
|
||||
self.provider_suggestions("provider", rest),
|
||||
);
|
||||
if results.len() < MAX_RESULTS {
|
||||
push_entries(&mut results, &mut seen, self.command_entries(trimmed));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
if lowered == "provider" {
|
||||
push_entries(
|
||||
&mut results,
|
||||
&mut seen,
|
||||
self.provider_suggestions("provider", ""),
|
||||
);
|
||||
if results.len() < MAX_RESULTS {
|
||||
push_entries(&mut results, &mut seen, self.command_entries(trimmed));
|
||||
}
|
||||
return results;
|
||||
}
|
||||
|
||||
// General query – combine commands, models, and providers using fuzzy order.
|
||||
push_entries(&mut results, &mut seen, self.command_entries(trimmed));
|
||||
if results.len() < MAX_RESULTS {
|
||||
push_entries(
|
||||
&mut results,
|
||||
&mut seen,
|
||||
self.model_suggestions("model", trimmed),
|
||||
);
|
||||
}
|
||||
if results.len() < MAX_RESULTS {
|
||||
push_entries(
|
||||
&mut results,
|
||||
&mut seen,
|
||||
self.provider_suggestions("provider", trimmed),
|
||||
);
|
||||
}
|
||||
|
||||
results
|
||||
}
|
||||
|
||||
fn model_suggestions(&self, keyword: &str, query: &str) -> Vec<String> {
|
||||
fn history_suggestions(&self, query: &str) -> Vec<PaletteSuggestion> {
|
||||
if self.history.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
if query.trim().is_empty() {
|
||||
return self
|
||||
.history
|
||||
.iter()
|
||||
.rev()
|
||||
.take(MAX_HISTORY_RESULTS)
|
||||
.map(|value| PaletteSuggestion {
|
||||
value: value.to_string(),
|
||||
label: value.to_string(),
|
||||
detail: Some("Recent command".to_string()),
|
||||
group: PaletteGroup::History,
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
|
||||
let mut matches: Vec<(usize, usize, usize, &String)> = self
|
||||
.history
|
||||
.iter()
|
||||
.rev()
|
||||
.enumerate()
|
||||
.filter_map(|(recency, value)| {
|
||||
commands::match_score(value, query)
|
||||
.map(|(primary, secondary)| (primary, secondary, recency, value))
|
||||
})
|
||||
.collect();
|
||||
|
||||
matches.sort_by(|a, b| a.0.cmp(&b.0).then(a.1.cmp(&b.1)).then(a.2.cmp(&b.2)));
|
||||
|
||||
matches
|
||||
.into_iter()
|
||||
.take(MAX_HISTORY_RESULTS)
|
||||
.map(|(_, _, _, value)| PaletteSuggestion {
|
||||
value: value.to_string(),
|
||||
label: value.to_string(),
|
||||
detail: Some("Recent command".to_string()),
|
||||
group: PaletteGroup::History,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn command_entries(&self, query: &str) -> Vec<PaletteSuggestion> {
|
||||
let specs: Vec<CommandSpec> = commands::suggestions(query);
|
||||
specs
|
||||
.into_iter()
|
||||
.map(|spec| PaletteSuggestion {
|
||||
value: spec.keyword.to_string(),
|
||||
label: spec.keyword.to_string(),
|
||||
detail: Some(spec.description.to_string()),
|
||||
group: PaletteGroup::Command,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn model_suggestions(&self, keyword: &str, query: &str) -> Vec<PaletteSuggestion> {
|
||||
if query.is_empty() {
|
||||
return self
|
||||
.models
|
||||
.iter()
|
||||
.take(15)
|
||||
.map(|entry| format!("{keyword} {}", entry.id))
|
||||
.map(|entry| PaletteSuggestion {
|
||||
value: format!("{keyword} {}", entry.id),
|
||||
label: entry.display_name().to_string(),
|
||||
detail: Some(format!("Model · {}", entry.provider)),
|
||||
group: PaletteGroup::Model,
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
|
||||
@@ -174,17 +356,27 @@ impl CommandPalette {
|
||||
matches
|
||||
.into_iter()
|
||||
.take(15)
|
||||
.map(|(_, _, entry)| format!("{keyword} {}", entry.id))
|
||||
.map(|(_, _, entry)| PaletteSuggestion {
|
||||
value: format!("{keyword} {}", entry.id),
|
||||
label: entry.display_name().to_string(),
|
||||
detail: Some(format!("Model · {}", entry.provider)),
|
||||
group: PaletteGroup::Model,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn provider_suggestions(&self, keyword: &str, query: &str) -> Vec<String> {
|
||||
fn provider_suggestions(&self, keyword: &str, query: &str) -> Vec<PaletteSuggestion> {
|
||||
if query.is_empty() {
|
||||
return self
|
||||
.providers
|
||||
.iter()
|
||||
.take(15)
|
||||
.map(|provider| format!("{keyword} {}", provider))
|
||||
.map(|provider| PaletteSuggestion {
|
||||
value: format!("{keyword} {}", provider),
|
||||
label: provider.to_string(),
|
||||
detail: Some("Provider".to_string()),
|
||||
group: PaletteGroup::Provider,
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
|
||||
@@ -201,7 +393,47 @@ impl CommandPalette {
|
||||
matches
|
||||
.into_iter()
|
||||
.take(15)
|
||||
.map(|(_, _, provider)| format!("{keyword} {}", provider))
|
||||
.map(|(_, _, provider)| PaletteSuggestion {
|
||||
value: format!("{keyword} {}", provider),
|
||||
label: provider.to_string(),
|
||||
detail: Some("Provider".to_string()),
|
||||
group: PaletteGroup::Provider,
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn history_entries_are_prioritized() {
|
||||
let mut palette = CommandPalette::new();
|
||||
palette.remember("open foo.rs");
|
||||
palette.remember("model llama");
|
||||
palette.ensure_suggestions();
|
||||
|
||||
let suggestions = palette.suggestions();
|
||||
assert!(!suggestions.is_empty());
|
||||
assert_eq!(suggestions[0].value, "model llama");
|
||||
assert!(matches!(suggestions[0].group, PaletteGroup::History));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn history_deduplicates_case_insensitively() {
|
||||
let mut palette = CommandPalette::new();
|
||||
palette.remember("open foo.rs");
|
||||
palette.remember("OPEN FOO.RS");
|
||||
palette.ensure_suggestions();
|
||||
|
||||
let history_entries: Vec<_> = palette
|
||||
.suggestions()
|
||||
.iter()
|
||||
.filter(|entry| matches!(entry.group, PaletteGroup::History))
|
||||
.collect();
|
||||
|
||||
assert_eq!(history_entries.len(), 1);
|
||||
assert_eq!(history_entries[0].value, "OPEN FOO.RS");
|
||||
}
|
||||
}
|
||||
|
||||
235
crates/owlen-tui/src/state/debug_log.rs
Normal file
235
crates/owlen-tui/src/state/debug_log.rs
Normal file
@@ -0,0 +1,235 @@
|
||||
use chrono::{DateTime, Local};
|
||||
use log::{Level, LevelFilter, Metadata, Record};
|
||||
use once_cell::sync::{Lazy, OnceCell};
|
||||
use regex::Regex;
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::Mutex;
|
||||
|
||||
/// Maximum number of entries to retain in the in-memory ring buffer.
|
||||
const MAX_ENTRIES: usize = 256;
|
||||
|
||||
/// Global access handle for the debug log store.
|
||||
static STORE: Lazy<DebugLogStore> = Lazy::new(DebugLogStore::default);
|
||||
static LOGGER: OnceCell<()> = OnceCell::new();
|
||||
static DEBUG_LOGGER: DebugLogger = DebugLogger;
|
||||
|
||||
/// Install the in-process logger that feeds the debug log ring buffer.
|
||||
pub fn install_global_logger() {
|
||||
LOGGER.get_or_init(|| {
|
||||
if log::set_logger(&DEBUG_LOGGER).is_ok() {
|
||||
log::set_max_level(LevelFilter::Trace);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/// Per-application state for presenting and acknowledging debug log entries.
|
||||
#[derive(Debug)]
|
||||
pub struct DebugLogState {
|
||||
visible: bool,
|
||||
last_seen_id: u64,
|
||||
}
|
||||
|
||||
impl DebugLogState {
|
||||
pub fn new() -> Self {
|
||||
let last_seen_id = STORE.latest_id();
|
||||
Self {
|
||||
visible: false,
|
||||
last_seen_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn toggle_visible(&mut self) -> bool {
|
||||
self.visible = !self.visible;
|
||||
if self.visible {
|
||||
self.mark_seen();
|
||||
}
|
||||
self.visible
|
||||
}
|
||||
|
||||
pub fn set_visible(&mut self, visible: bool) {
|
||||
self.visible = visible;
|
||||
if visible {
|
||||
self.mark_seen();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_visible(&self) -> bool {
|
||||
self.visible
|
||||
}
|
||||
|
||||
pub fn entries(&self) -> Vec<DebugLogEntry> {
|
||||
STORE.snapshot()
|
||||
}
|
||||
|
||||
pub fn take_unseen(&mut self) -> Vec<DebugLogEntry> {
|
||||
let entries = STORE.entries_since(self.last_seen_id);
|
||||
if let Some(entry) = entries.last() {
|
||||
self.last_seen_id = entry.id;
|
||||
}
|
||||
entries
|
||||
}
|
||||
|
||||
pub fn has_unseen(&self) -> bool {
|
||||
STORE.latest_id() > self.last_seen_id
|
||||
}
|
||||
|
||||
fn mark_seen(&mut self) {
|
||||
self.last_seen_id = STORE.latest_id();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for DebugLogState {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Metadata describing a single debug log entry.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct DebugLogEntry {
|
||||
pub id: u64,
|
||||
pub timestamp: DateTime<Local>,
|
||||
pub level: Level,
|
||||
pub target: String,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct DebugLogStore {
|
||||
inner: Mutex<Inner>,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct Inner {
|
||||
entries: VecDeque<DebugLogEntry>,
|
||||
next_id: u64,
|
||||
}
|
||||
|
||||
impl DebugLogStore {
|
||||
fn snapshot(&self) -> Vec<DebugLogEntry> {
|
||||
let inner = self.inner.lock().unwrap();
|
||||
inner.entries.iter().cloned().collect()
|
||||
}
|
||||
|
||||
fn latest_id(&self) -> u64 {
|
||||
let inner = self.inner.lock().unwrap();
|
||||
inner.next_id
|
||||
}
|
||||
|
||||
fn entries_since(&self, last_seen_id: u64) -> Vec<DebugLogEntry> {
|
||||
let inner = self.inner.lock().unwrap();
|
||||
inner
|
||||
.entries
|
||||
.iter()
|
||||
.filter(|entry| entry.id > last_seen_id)
|
||||
.cloned()
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn push(&self, level: Level, target: &str, message: &str) -> DebugLogEntry {
|
||||
let sanitized = sanitize_message(message);
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
inner.next_id = inner.next_id.saturating_add(1);
|
||||
let entry = DebugLogEntry {
|
||||
id: inner.next_id,
|
||||
timestamp: Local::now(),
|
||||
level,
|
||||
target: target.to_string(),
|
||||
message: sanitized,
|
||||
};
|
||||
inner.entries.push_back(entry.clone());
|
||||
while inner.entries.len() > MAX_ENTRIES {
|
||||
inner.entries.pop_front();
|
||||
}
|
||||
entry
|
||||
}
|
||||
}
|
||||
|
||||
struct DebugLogger;
|
||||
|
||||
impl log::Log for DebugLogger {
|
||||
fn enabled(&self, metadata: &Metadata) -> bool {
|
||||
metadata.level() <= LevelFilter::Trace
|
||||
}
|
||||
|
||||
fn log(&self, record: &Record) {
|
||||
if !self.enabled(record.metadata()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Only persist warnings and errors in the in-memory buffer.
|
||||
if record.level() < Level::Warn {
|
||||
return;
|
||||
}
|
||||
|
||||
let message = record.args().to_string();
|
||||
let entry = STORE.push(record.level(), record.target(), &message);
|
||||
|
||||
if record.level() == Level::Error {
|
||||
eprintln!(
|
||||
"[owlen:error][{}] {}",
|
||||
entry.timestamp.format("%Y-%m-%d %H:%M:%S"),
|
||||
entry.message
|
||||
);
|
||||
} else if record.level() == Level::Warn {
|
||||
eprintln!(
|
||||
"[owlen:warn][{}] {}",
|
||||
entry.timestamp.format("%Y-%m-%d %H:%M:%S"),
|
||||
entry.message
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&self) {}
|
||||
}
|
||||
|
||||
fn sanitize_message(message: &str) -> String {
|
||||
static AUTH_HEADER: Lazy<Regex> =
|
||||
Lazy::new(|| Regex::new(r"(?i)\b(authorization)(\s*[:=]\s*)([^\r\n]+)").unwrap());
|
||||
static GENERIC_SECRET: Lazy<Regex> =
|
||||
Lazy::new(|| Regex::new(r"(?i)\b(api[_-]?key|token)(\s*[:=]\s*)([^,\s;]+)").unwrap());
|
||||
static BEARER_TOKEN: Lazy<Regex> =
|
||||
Lazy::new(|| Regex::new(r"(?i)\bBearer\s+[A-Za-z0-9._\-+/=]+").unwrap());
|
||||
|
||||
let step = AUTH_HEADER.replace_all(message, |caps: ®ex::Captures<'_>| {
|
||||
format!("{}{}<redacted>", &caps[1], &caps[2])
|
||||
});
|
||||
|
||||
let step = GENERIC_SECRET.replace_all(&step, |caps: ®ex::Captures<'_>| {
|
||||
format!("{}{}<redacted>", &caps[1], &caps[2])
|
||||
});
|
||||
|
||||
BEARER_TOKEN
|
||||
.replace_all(&step, "Bearer <redacted>")
|
||||
.into_owned()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn sanitize_masks_common_tokens() {
|
||||
let input =
|
||||
"Authorization: Bearer abc123 token=xyz456 KEY=value Authorization=Token secretStuff";
|
||||
let sanitized = sanitize_message(input);
|
||||
assert!(!sanitized.contains("abc123"));
|
||||
assert!(!sanitized.contains("xyz456"));
|
||||
assert!(!sanitized.contains("secretStuff"));
|
||||
assert_eq!(sanitized, "Authorization: <redacted>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ring_buffer_discards_old_entries() {
|
||||
install_global_logger();
|
||||
let initial_latest = STORE.latest_id();
|
||||
for idx in 0..(MAX_ENTRIES as u64 + 10) {
|
||||
let message = format!("warn #{idx}");
|
||||
STORE.push(Level::Warn, "test", &message);
|
||||
}
|
||||
|
||||
let entries = STORE.snapshot();
|
||||
assert_eq!(entries.len(), MAX_ENTRIES);
|
||||
assert!(entries.first().unwrap().id > initial_latest);
|
||||
}
|
||||
}
|
||||
320
crates/owlen-tui/src/state/file_icons.rs
Normal file
320
crates/owlen-tui/src/state/file_icons.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
use std::env;
|
||||
use std::path::Path;
|
||||
|
||||
use owlen_core::config::IconMode;
|
||||
use unicode_width::UnicodeWidthChar;
|
||||
|
||||
use super::FileNode;
|
||||
|
||||
const ENV_ICON_OVERRIDE: &str = "OWLEN_TUI_ICONS";
|
||||
|
||||
/// Concrete icon sets that can be rendered in the terminal.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum FileIconSet {
|
||||
Nerd,
|
||||
Ascii,
|
||||
}
|
||||
|
||||
impl FileIconSet {
|
||||
pub fn label(self) -> &'static str {
|
||||
match self {
|
||||
FileIconSet::Nerd => "Nerd",
|
||||
FileIconSet::Ascii => "ASCII",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// How the icon mode was decided.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum IconDetection {
|
||||
/// Explicit configuration (config file or CLI flag) forced the mode.
|
||||
Configured,
|
||||
/// The runtime environment variable override selected the mode.
|
||||
Environment,
|
||||
/// Automatic heuristics guessed the appropriate mode.
|
||||
Heuristic,
|
||||
}
|
||||
|
||||
/// Resolves per-file icons with configurable fallbacks.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FileIconResolver {
|
||||
set: FileIconSet,
|
||||
detection: IconDetection,
|
||||
}
|
||||
|
||||
impl FileIconResolver {
|
||||
/// Construct a resolver from the configured icon preference.
|
||||
pub fn from_mode(pref: IconMode) -> Self {
|
||||
let (set, detection) = match pref {
|
||||
IconMode::Ascii => (FileIconSet::Ascii, IconDetection::Configured),
|
||||
IconMode::Nerd => (FileIconSet::Nerd, IconDetection::Configured),
|
||||
IconMode::Auto => detect_icon_set(),
|
||||
};
|
||||
Self { set, detection }
|
||||
}
|
||||
|
||||
/// Effective icon set that will be rendered.
|
||||
pub fn set(&self) -> FileIconSet {
|
||||
self.set
|
||||
}
|
||||
|
||||
/// How the icon set was chosen.
|
||||
pub fn detection(&self) -> IconDetection {
|
||||
self.detection
|
||||
}
|
||||
|
||||
/// Human readable label for status lines.
|
||||
pub fn status_label(&self) -> &'static str {
|
||||
self.set.label()
|
||||
}
|
||||
|
||||
/// Short label indicating where the decision originated.
|
||||
pub fn detection_label(&self) -> &'static str {
|
||||
match self.detection {
|
||||
IconDetection::Configured => "config",
|
||||
IconDetection::Environment => "env",
|
||||
IconDetection::Heuristic => "auto",
|
||||
}
|
||||
}
|
||||
|
||||
/// Select the glyph to render for the given node.
|
||||
pub fn icon_for(&self, node: &FileNode) -> &'static str {
|
||||
match self.set {
|
||||
FileIconSet::Nerd => nerd_icon_for(node),
|
||||
FileIconSet::Ascii => ascii_icon_for(node),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn detect_icon_set() -> (FileIconSet, IconDetection) {
|
||||
if let Some(set) = env_icon_override() {
|
||||
return (set, IconDetection::Environment);
|
||||
}
|
||||
|
||||
if !locale_supports_unicode() || is_basic_terminal() {
|
||||
return (FileIconSet::Ascii, IconDetection::Heuristic);
|
||||
}
|
||||
|
||||
if nerd_glyph_has_compact_width() {
|
||||
(FileIconSet::Nerd, IconDetection::Heuristic)
|
||||
} else {
|
||||
(FileIconSet::Ascii, IconDetection::Heuristic)
|
||||
}
|
||||
}
|
||||
|
||||
fn env_icon_override() -> Option<FileIconSet> {
|
||||
let value = env::var(ENV_ICON_OVERRIDE).ok()?;
|
||||
match value.trim().to_ascii_lowercase().as_str() {
|
||||
"nerd" | "nerdfont" | "nf" | "fancy" => Some(FileIconSet::Nerd),
|
||||
"ascii" | "plain" | "simple" => Some(FileIconSet::Ascii),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn locale_supports_unicode() -> bool {
|
||||
let vars = ["LC_ALL", "LC_CTYPE", "LANG"];
|
||||
vars.iter()
|
||||
.filter_map(|name| env::var(name).ok())
|
||||
.map(|value| value.to_ascii_lowercase())
|
||||
.any(|value| value.contains("utf-8") || value.contains("utf8"))
|
||||
}
|
||||
|
||||
fn is_basic_terminal() -> bool {
|
||||
matches!(env::var("TERM").ok().as_deref(), Some("linux" | "vt100"))
|
||||
}
|
||||
|
||||
fn nerd_glyph_has_compact_width() -> bool {
|
||||
// Sample glyphs chosen from the Nerd Font private use area.
|
||||
const SAMPLE_ICONS: [&str; 3] = ["", "", ""];
|
||||
SAMPLE_ICONS.iter().all(|icon| {
|
||||
icon.chars()
|
||||
.all(|ch| UnicodeWidthChar::width(ch).unwrap_or(1) == 1)
|
||||
})
|
||||
}
|
||||
|
||||
fn nerd_icon_for(node: &FileNode) -> &'static str {
|
||||
if node.depth == 0 {
|
||||
return "";
|
||||
}
|
||||
if node.is_dir {
|
||||
return if node.is_expanded { "" } else { "" };
|
||||
}
|
||||
|
||||
let name = node.name.as_str();
|
||||
if let Some(icon) = nerd_icon_by_special_name(name) {
|
||||
return icon;
|
||||
}
|
||||
|
||||
let ext = Path::new(name)
|
||||
.extension()
|
||||
.and_then(|ext| ext.to_str())
|
||||
.unwrap_or_default()
|
||||
.to_ascii_lowercase();
|
||||
|
||||
match ext.as_str() {
|
||||
"rs" => "",
|
||||
"toml" => "",
|
||||
"lock" => "",
|
||||
"json" => "",
|
||||
"yaml" | "yml" => "",
|
||||
"md" | "markdown" => "",
|
||||
"py" => "",
|
||||
"rb" => "",
|
||||
"go" => "",
|
||||
"sh" | "bash" => "",
|
||||
"zsh" => "",
|
||||
"fish" => "",
|
||||
"ts" => "",
|
||||
"tsx" => "",
|
||||
"js" => "",
|
||||
"jsx" => "",
|
||||
"mjs" | "cjs" => "",
|
||||
"html" | "htm" => "",
|
||||
"css" => "",
|
||||
"scss" | "sass" => "",
|
||||
"less" => "",
|
||||
"vue" => "",
|
||||
"svelte" => "",
|
||||
"java" => "",
|
||||
"kt" => "",
|
||||
"swift" => "",
|
||||
"c" => "",
|
||||
"h" => "",
|
||||
"cpp" | "cxx" | "cc" => "",
|
||||
"hpp" | "hh" | "hxx" => "",
|
||||
"cs" => "",
|
||||
"php" => "",
|
||||
"zig" => "",
|
||||
"lua" => "",
|
||||
"sql" => "",
|
||||
"erl" | "hrl" => "",
|
||||
"ex" | "exs" => "",
|
||||
"hs" => "",
|
||||
"scala" => "",
|
||||
"dart" => "",
|
||||
"gradle" => "",
|
||||
"groovy" => "",
|
||||
"xml" => "",
|
||||
"ini" | "cfg" => "",
|
||||
"env" => "",
|
||||
"log" => "",
|
||||
"txt" => "",
|
||||
"pdf" => "",
|
||||
"png" | "jpg" | "jpeg" | "gif" | "webp" | "bmp" => "",
|
||||
"svg" => "",
|
||||
"ico" => "",
|
||||
"lockb" => "",
|
||||
"wasm" => "",
|
||||
_ => "",
|
||||
}
|
||||
}
|
||||
|
||||
fn nerd_icon_by_special_name(name: &str) -> Option<&'static str> {
|
||||
match name {
|
||||
"Cargo.toml" => Some(""),
|
||||
"Cargo.lock" => Some(""),
|
||||
"Makefile" | "makefile" => Some(""),
|
||||
"Dockerfile" => Some(""),
|
||||
".gitignore" => Some(""),
|
||||
".gitmodules" => Some(""),
|
||||
"README.md" | "readme.md" => Some(""),
|
||||
"LICENSE" | "LICENSE.md" | "LICENSE.txt" => Some(""),
|
||||
"package.json" => Some(""),
|
||||
"package-lock.json" => Some(""),
|
||||
"yarn.lock" => Some(""),
|
||||
"pnpm-lock.yaml" | "pnpm-lock.yml" => Some(""),
|
||||
"tsconfig.json" => Some(""),
|
||||
"config.toml" => Some(""),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn ascii_icon_for(node: &FileNode) -> &'static str {
|
||||
if node.depth == 0 {
|
||||
return "[]";
|
||||
}
|
||||
if node.is_dir {
|
||||
return if node.is_expanded { "[]" } else { "<>" };
|
||||
}
|
||||
|
||||
let name = node.name.as_str();
|
||||
if let Some(icon) = ascii_icon_by_special_name(name) {
|
||||
return icon;
|
||||
}
|
||||
|
||||
let ext = Path::new(name)
|
||||
.extension()
|
||||
.and_then(|ext| ext.to_str())
|
||||
.unwrap_or_default()
|
||||
.to_ascii_lowercase();
|
||||
|
||||
match ext.as_str() {
|
||||
"rs" => "RS",
|
||||
"toml" => "TL",
|
||||
"lock" => "LK",
|
||||
"json" => "JS",
|
||||
"yaml" | "yml" => "YM",
|
||||
"md" | "markdown" => "MD",
|
||||
"py" => "PY",
|
||||
"rb" => "RB",
|
||||
"go" => "GO",
|
||||
"sh" | "bash" | "zsh" | "fish" => "SH",
|
||||
"ts" => "TS",
|
||||
"tsx" => "TX",
|
||||
"js" | "jsx" | "mjs" | "cjs" => "JS",
|
||||
"html" | "htm" => "HT",
|
||||
"css" => "CS",
|
||||
"scss" | "sass" => "SC",
|
||||
"vue" => "VU",
|
||||
"svelte" => "SV",
|
||||
"java" => "JV",
|
||||
"kt" => "KT",
|
||||
"swift" => "SW",
|
||||
"c" => "C",
|
||||
"h" => "H",
|
||||
"cpp" | "cxx" | "cc" => "C+",
|
||||
"hpp" | "hh" | "hxx" => "H+",
|
||||
"cs" => "CS",
|
||||
"php" => "PH",
|
||||
"zig" => "ZG",
|
||||
"lua" => "LU",
|
||||
"sql" => "SQ",
|
||||
"erl" | "hrl" => "ER",
|
||||
"ex" | "exs" => "EX",
|
||||
"hs" => "HS",
|
||||
"scala" => "SC",
|
||||
"dart" => "DT",
|
||||
"gradle" => "GR",
|
||||
"groovy" => "GR",
|
||||
"xml" => "XM",
|
||||
"ini" | "cfg" => "CF",
|
||||
"env" => "EV",
|
||||
"log" => "LG",
|
||||
"txt" => "--",
|
||||
"pdf" => "PD",
|
||||
"png" | "jpg" | "jpeg" | "gif" | "webp" | "bmp" => "IM",
|
||||
"svg" => "SG",
|
||||
"wasm" => "WM",
|
||||
_ => "--",
|
||||
}
|
||||
}
|
||||
|
||||
fn ascii_icon_by_special_name(name: &str) -> Option<&'static str> {
|
||||
match name {
|
||||
"Cargo.toml" => Some("TL"),
|
||||
"Cargo.lock" => Some("LK"),
|
||||
"Makefile" | "makefile" => Some("MK"),
|
||||
"Dockerfile" => Some("DK"),
|
||||
".gitignore" => Some("GI"),
|
||||
".gitmodules" => Some("GI"),
|
||||
"README.md" | "readme.md" => Some("MD"),
|
||||
"LICENSE" | "LICENSE.md" | "LICENSE.txt" => Some("LC"),
|
||||
"package.json" => Some("PJ"),
|
||||
"package-lock.json" => Some("PL"),
|
||||
"yarn.lock" => Some("YL"),
|
||||
"pnpm-lock.yaml" | "pnpm-lock.yml" => Some("PL"),
|
||||
"tsconfig.json" => Some("TC"),
|
||||
"config.toml" => Some("CF"),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
723
crates/owlen-tui/src/state/file_tree.rs
Normal file
723
crates/owlen-tui/src/state/file_tree.rs
Normal file
@@ -0,0 +1,723 @@
|
||||
use crate::commands;
|
||||
use anyhow::{Context, Result};
|
||||
use globset::{Glob, GlobBuilder, GlobSetBuilder};
|
||||
use ignore::WalkBuilder;
|
||||
use pathdiff::diff_paths;
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::OsStr;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::Command;
|
||||
|
||||
/// Indicates which matching strategy is applied when filtering the file tree.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum FilterMode {
|
||||
Glob,
|
||||
Fuzzy,
|
||||
}
|
||||
|
||||
/// Git-related decorations rendered alongside a file entry.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GitDecoration {
|
||||
pub badge: Option<char>,
|
||||
pub cleanliness: char,
|
||||
}
|
||||
|
||||
impl GitDecoration {
|
||||
pub fn clean() -> Self {
|
||||
Self {
|
||||
badge: None,
|
||||
cleanliness: '✓',
|
||||
}
|
||||
}
|
||||
|
||||
pub fn staged(badge: Option<char>) -> Self {
|
||||
Self {
|
||||
badge,
|
||||
cleanliness: '○',
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dirty(badge: Option<char>) -> Self {
|
||||
Self {
|
||||
badge,
|
||||
cleanliness: '●',
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Node representing a single entry (file or directory) in the tree.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FileNode {
|
||||
pub name: String,
|
||||
pub path: PathBuf,
|
||||
pub parent: Option<usize>,
|
||||
pub children: Vec<usize>,
|
||||
pub depth: usize,
|
||||
pub is_dir: bool,
|
||||
pub is_expanded: bool,
|
||||
pub is_hidden: bool,
|
||||
pub git: GitDecoration,
|
||||
}
|
||||
|
||||
impl FileNode {
|
||||
fn should_default_expand(&self) -> bool {
|
||||
self.depth < 2
|
||||
}
|
||||
}
|
||||
|
||||
/// Visible entry metadata returned to the renderer.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct VisibleFileEntry {
|
||||
pub index: usize,
|
||||
pub depth: usize,
|
||||
}
|
||||
|
||||
/// Tracks the entire file tree state including filters, selection, and scroll.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FileTreeState {
|
||||
root: PathBuf,
|
||||
repo_name: String,
|
||||
nodes: Vec<FileNode>,
|
||||
visible: Vec<VisibleFileEntry>,
|
||||
cursor: usize,
|
||||
scroll_top: usize,
|
||||
viewport_height: usize,
|
||||
filter_mode: FilterMode,
|
||||
filter_query: String,
|
||||
show_hidden: bool,
|
||||
filter_matches: Vec<bool>,
|
||||
last_error: Option<String>,
|
||||
git_branch: Option<String>,
|
||||
}
|
||||
|
||||
impl FileTreeState {
|
||||
/// Construct a new file tree rooted at the provided path.
|
||||
pub fn new(root: impl Into<PathBuf>) -> Self {
|
||||
let mut root_path = root.into();
|
||||
if let Ok(canonical) = root_path.canonicalize() {
|
||||
root_path = canonical;
|
||||
}
|
||||
let repo_name = root_path
|
||||
.file_name()
|
||||
.map(|s| s.to_string_lossy().into_owned())
|
||||
.unwrap_or_else(|| root_path.display().to_string());
|
||||
|
||||
let mut state = Self {
|
||||
root: root_path,
|
||||
repo_name,
|
||||
nodes: Vec::new(),
|
||||
visible: Vec::new(),
|
||||
cursor: 0,
|
||||
scroll_top: 0,
|
||||
viewport_height: 20,
|
||||
filter_mode: FilterMode::Glob,
|
||||
filter_query: String::new(),
|
||||
show_hidden: false,
|
||||
filter_matches: Vec::new(),
|
||||
last_error: None,
|
||||
git_branch: None,
|
||||
};
|
||||
|
||||
if let Err(err) = state.refresh() {
|
||||
state.nodes.clear();
|
||||
state.visible.clear();
|
||||
state.filter_matches.clear();
|
||||
state.last_error = Some(err.to_string());
|
||||
}
|
||||
|
||||
state
|
||||
}
|
||||
|
||||
/// Rebuild the file tree from disk and recompute visibility.
|
||||
pub fn refresh(&mut self) -> Result<()> {
|
||||
let git_map = collect_git_status(&self.root).unwrap_or_default();
|
||||
self.nodes = build_nodes(&self.root, self.show_hidden, git_map)?;
|
||||
self.git_branch = current_git_branch(&self.root).unwrap_or(None);
|
||||
if self.nodes.is_empty() {
|
||||
self.visible.clear();
|
||||
self.filter_matches.clear();
|
||||
self.cursor = 0;
|
||||
return Ok(());
|
||||
}
|
||||
self.ensure_valid_cursor();
|
||||
self.recompute_filter_cache();
|
||||
self.rebuild_visible();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn repo_name(&self) -> &str {
|
||||
&self.repo_name
|
||||
}
|
||||
|
||||
pub fn root(&self) -> &Path {
|
||||
&self.root
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.visible.is_empty()
|
||||
}
|
||||
|
||||
pub fn visible_entries(&self) -> &[VisibleFileEntry] {
|
||||
&self.visible
|
||||
}
|
||||
|
||||
pub fn nodes(&self) -> &[FileNode] {
|
||||
&self.nodes
|
||||
}
|
||||
|
||||
pub fn selected_index(&self) -> Option<usize> {
|
||||
self.visible.get(self.cursor).map(|entry| entry.index)
|
||||
}
|
||||
|
||||
pub fn selected_node(&self) -> Option<&FileNode> {
|
||||
self.selected_index().and_then(|idx| self.nodes.get(idx))
|
||||
}
|
||||
|
||||
pub fn selected_node_mut(&mut self) -> Option<&mut FileNode> {
|
||||
let idx = self.selected_index()?;
|
||||
self.nodes.get_mut(idx)
|
||||
}
|
||||
|
||||
pub fn cursor(&self) -> usize {
|
||||
self.cursor
|
||||
}
|
||||
|
||||
pub fn scroll_top(&self) -> usize {
|
||||
self.scroll_top
|
||||
}
|
||||
|
||||
pub fn viewport_height(&self) -> usize {
|
||||
self.viewport_height
|
||||
}
|
||||
|
||||
pub fn filter_mode(&self) -> FilterMode {
|
||||
self.filter_mode
|
||||
}
|
||||
|
||||
pub fn filter_query(&self) -> &str {
|
||||
&self.filter_query
|
||||
}
|
||||
|
||||
pub fn set_filter_mode(&mut self, mode: FilterMode) {
|
||||
if self.filter_mode != mode {
|
||||
self.filter_mode = mode;
|
||||
self.recompute_filter_cache();
|
||||
self.rebuild_visible();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn show_hidden(&self) -> bool {
|
||||
self.show_hidden
|
||||
}
|
||||
|
||||
pub fn git_branch(&self) -> Option<&str> {
|
||||
self.git_branch.as_deref()
|
||||
}
|
||||
|
||||
pub fn last_error(&self) -> Option<&str> {
|
||||
self.last_error.as_deref()
|
||||
}
|
||||
|
||||
pub fn set_viewport_height(&mut self, height: usize) {
|
||||
self.viewport_height = height.max(1);
|
||||
self.ensure_cursor_in_view();
|
||||
}
|
||||
|
||||
pub fn move_cursor(&mut self, delta: isize) {
|
||||
if self.visible.is_empty() {
|
||||
self.cursor = 0;
|
||||
self.scroll_top = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
let len = self.visible.len() as isize;
|
||||
let new_cursor = (self.cursor as isize + delta).clamp(0, len - 1) as usize;
|
||||
self.cursor = new_cursor;
|
||||
self.ensure_cursor_in_view();
|
||||
}
|
||||
|
||||
pub fn jump_to_top(&mut self) {
|
||||
if !self.visible.is_empty() {
|
||||
self.cursor = 0;
|
||||
self.scroll_top = 0;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn jump_to_bottom(&mut self) {
|
||||
if !self.visible.is_empty() {
|
||||
self.cursor = self.visible.len().saturating_sub(1);
|
||||
let viewport = self.viewport_height.max(1);
|
||||
self.scroll_top = self.visible.len().saturating_sub(viewport);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn page_down(&mut self) {
|
||||
let amount = self.viewport_height.max(1) as isize;
|
||||
self.move_cursor(amount);
|
||||
}
|
||||
|
||||
pub fn page_up(&mut self) {
|
||||
let amount = -(self.viewport_height.max(1) as isize);
|
||||
self.move_cursor(amount);
|
||||
}
|
||||
|
||||
pub fn toggle_expand(&mut self) {
|
||||
if let Some(node) = self.selected_node_mut() {
|
||||
if !node.is_dir {
|
||||
return;
|
||||
}
|
||||
node.is_expanded = !node.is_expanded;
|
||||
self.rebuild_visible();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_filter_query(&mut self, query: impl Into<String>) {
|
||||
self.filter_query = query.into();
|
||||
self.recompute_filter_cache();
|
||||
self.rebuild_visible();
|
||||
}
|
||||
|
||||
pub fn clear_filter(&mut self) {
|
||||
self.filter_query.clear();
|
||||
self.recompute_filter_cache();
|
||||
self.rebuild_visible();
|
||||
}
|
||||
|
||||
pub fn toggle_filter_mode(&mut self) {
|
||||
let next = match self.filter_mode {
|
||||
FilterMode::Glob => FilterMode::Fuzzy,
|
||||
FilterMode::Fuzzy => FilterMode::Glob,
|
||||
};
|
||||
self.set_filter_mode(next);
|
||||
}
|
||||
|
||||
pub fn toggle_hidden(&mut self) -> Result<()> {
|
||||
self.show_hidden = !self.show_hidden;
|
||||
self.refresh()
|
||||
}
|
||||
|
||||
/// Expand directories along the provided path and position the cursor.
|
||||
pub fn reveal(&mut self, path: &Path) {
|
||||
if self.nodes.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(rel) = diff_paths(path, &self.root) {
|
||||
if let Some(index) = self
|
||||
.nodes
|
||||
.iter()
|
||||
.position(|node| node.path == rel || node.path == path)
|
||||
{
|
||||
self.expand_to(index);
|
||||
if let Some(cursor_pos) = self.visible.iter().position(|entry| entry.index == index)
|
||||
{
|
||||
self.cursor = cursor_pos;
|
||||
self.ensure_cursor_in_view();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn expand_to(&mut self, index: usize) {
|
||||
let mut current = Some(index);
|
||||
while let Some(idx) = current {
|
||||
if let Some(parent) = self.nodes.get(idx).and_then(|node| node.parent) {
|
||||
if let Some(parent_node) = self.nodes.get_mut(parent) {
|
||||
parent_node.is_expanded = true;
|
||||
}
|
||||
current = Some(parent);
|
||||
} else {
|
||||
current = None;
|
||||
}
|
||||
}
|
||||
self.rebuild_visible();
|
||||
}
|
||||
|
||||
fn ensure_valid_cursor(&mut self) {
|
||||
if self.cursor >= self.visible.len() {
|
||||
self.cursor = self.visible.len().saturating_sub(1);
|
||||
}
|
||||
}
|
||||
|
||||
fn ensure_cursor_in_view(&mut self) {
|
||||
if self.visible.is_empty() {
|
||||
self.cursor = 0;
|
||||
self.scroll_top = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
let viewport = self.viewport_height.max(1);
|
||||
if self.cursor < self.scroll_top {
|
||||
self.scroll_top = self.cursor;
|
||||
} else if self.cursor >= self.scroll_top + viewport {
|
||||
self.scroll_top = self.cursor + 1 - viewport;
|
||||
}
|
||||
}
|
||||
|
||||
fn recompute_filter_cache(&mut self) {
|
||||
let has_filter = !self.filter_query.trim().is_empty();
|
||||
self.filter_matches = if !has_filter {
|
||||
vec![true; self.nodes.len()]
|
||||
} else {
|
||||
self.nodes
|
||||
.iter()
|
||||
.map(|node| match self.filter_mode {
|
||||
FilterMode::Glob => glob_matches(self.filter_query.trim(), node),
|
||||
FilterMode::Fuzzy => fuzzy_matches(self.filter_query.trim(), node),
|
||||
})
|
||||
.collect()
|
||||
};
|
||||
|
||||
if has_filter {
|
||||
// Ensure parent directories of matches are preserved.
|
||||
for idx in (0..self.nodes.len()).rev() {
|
||||
let children = self.nodes[idx].children.clone();
|
||||
if !self.filter_matches[idx]
|
||||
&& children
|
||||
.iter()
|
||||
.any(|child| self.filter_matches.get(*child).copied().unwrap_or(false))
|
||||
{
|
||||
self.filter_matches[idx] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn rebuild_visible(&mut self) {
|
||||
self.visible.clear();
|
||||
|
||||
if self.nodes.is_empty() {
|
||||
self.cursor = 0;
|
||||
self.scroll_top = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
let has_filter = !self.filter_query.trim().is_empty();
|
||||
self.walk_visible(0, has_filter);
|
||||
if self.visible.is_empty() {
|
||||
// At minimum show the root node.
|
||||
self.visible.push(VisibleFileEntry {
|
||||
index: 0,
|
||||
depth: self.nodes[0].depth,
|
||||
});
|
||||
}
|
||||
let max_index = self.visible.len().saturating_sub(1);
|
||||
self.cursor = self.cursor.min(max_index);
|
||||
self.ensure_cursor_in_view();
|
||||
}
|
||||
|
||||
fn walk_visible(&mut self, index: usize, filter_override: bool) {
|
||||
if !self.filter_matches.get(index).copied().unwrap_or(true) {
|
||||
return;
|
||||
}
|
||||
|
||||
let (depth, descend, children) = {
|
||||
let node = match self.nodes.get(index) {
|
||||
Some(node) => node,
|
||||
None => return,
|
||||
};
|
||||
let descend = if filter_override {
|
||||
node.is_dir
|
||||
} else {
|
||||
node.is_dir && node.is_expanded
|
||||
};
|
||||
let children = if node.is_dir {
|
||||
node.children.clone()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
(node.depth, descend, children)
|
||||
};
|
||||
|
||||
self.visible.push(VisibleFileEntry { index, depth });
|
||||
|
||||
if descend {
|
||||
for child in children {
|
||||
self.walk_visible(child, filter_override);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn glob_matches(pattern: &str, node: &FileNode) -> bool {
|
||||
if pattern.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
let mut builder = GlobSetBuilder::new();
|
||||
match GlobBuilder::new(pattern).literal_separator(true).build() {
|
||||
Ok(glob) => {
|
||||
builder.add(glob);
|
||||
if let Ok(set) = builder.build() {
|
||||
return set.is_match(&node.path) || set.is_match(node.name.as_str());
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
if let Ok(glob) = Glob::new("**") {
|
||||
builder.add(glob);
|
||||
if let Ok(set) = builder.build() {
|
||||
return set.is_match(&node.path);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
fn fuzzy_matches(query: &str, node: &FileNode) -> bool {
|
||||
if query.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
let path_str = node.path.to_string_lossy();
|
||||
let name = node.name.as_str();
|
||||
|
||||
commands::match_score(&path_str, query)
|
||||
.or_else(|| commands::match_score(name, query))
|
||||
.is_some()
|
||||
}
|
||||
|
||||
fn build_nodes(
|
||||
root: &Path,
|
||||
show_hidden: bool,
|
||||
git_map: HashMap<PathBuf, GitDecoration>,
|
||||
) -> Result<Vec<FileNode>> {
|
||||
let mut builder = WalkBuilder::new(root);
|
||||
builder.hidden(!show_hidden);
|
||||
builder.git_global(true);
|
||||
builder.git_ignore(true);
|
||||
builder.git_exclude(true);
|
||||
builder.follow_links(false);
|
||||
builder.sort_by_file_path(|a, b| a.file_name().cmp(&b.file_name()));
|
||||
|
||||
let owlen_ignore = root.join(".owlenignore");
|
||||
if owlen_ignore.exists() {
|
||||
builder.add_ignore(&owlen_ignore);
|
||||
}
|
||||
|
||||
let mut nodes: Vec<FileNode> = Vec::new();
|
||||
let mut index_by_path: HashMap<PathBuf, usize> = HashMap::new();
|
||||
|
||||
for result in builder.build() {
|
||||
let entry = match result {
|
||||
Ok(value) => value,
|
||||
Err(err) => {
|
||||
eprintln!("File tree walk error: {err}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Skip errors or entries without metadata.
|
||||
let file_type = match entry.file_type() {
|
||||
Some(ft) => ft,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let depth = entry.depth();
|
||||
if depth == 0 && !file_type.is_dir() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let relative = if depth == 0 {
|
||||
PathBuf::new()
|
||||
} else {
|
||||
diff_paths(entry.path(), root).unwrap_or_else(|| entry.path().to_path_buf())
|
||||
};
|
||||
|
||||
let name = if depth == 0 {
|
||||
root.file_name()
|
||||
.map(|s| s.to_string_lossy().into_owned())
|
||||
.unwrap_or_else(|| root.display().to_string())
|
||||
} else {
|
||||
entry.file_name().to_string_lossy().into_owned()
|
||||
};
|
||||
|
||||
let parent = if depth == 0 {
|
||||
None
|
||||
} else {
|
||||
entry
|
||||
.path()
|
||||
.parent()
|
||||
.and_then(|parent| diff_paths(parent, root))
|
||||
.and_then(|rel_parent| index_by_path.get(&rel_parent).copied())
|
||||
};
|
||||
|
||||
let git = git_map
|
||||
.get(&relative)
|
||||
.cloned()
|
||||
.unwrap_or_else(GitDecoration::clean);
|
||||
|
||||
let mut node = FileNode {
|
||||
name,
|
||||
path: relative.clone(),
|
||||
parent,
|
||||
children: Vec::new(),
|
||||
depth,
|
||||
is_dir: file_type.is_dir(),
|
||||
is_expanded: false,
|
||||
is_hidden: is_hidden(entry.file_name()),
|
||||
git,
|
||||
};
|
||||
|
||||
node.is_expanded = node.should_default_expand();
|
||||
|
||||
let index = nodes.len();
|
||||
if let Some(parent_idx) = parent {
|
||||
if let Some(parent_node) = nodes.get_mut(parent_idx) {
|
||||
parent_node.children.push(index);
|
||||
}
|
||||
}
|
||||
|
||||
index_by_path.insert(relative, index);
|
||||
nodes.push(node);
|
||||
}
|
||||
|
||||
propagate_directory_git_state(&mut nodes);
|
||||
Ok(nodes)
|
||||
}
|
||||
|
||||
fn is_hidden(name: &OsStr) -> bool {
|
||||
name.to_string_lossy().starts_with('.')
|
||||
}
|
||||
|
||||
fn propagate_directory_git_state(nodes: &mut [FileNode]) {
|
||||
for idx in (0..nodes.len()).rev() {
|
||||
if !nodes[idx].is_dir {
|
||||
continue;
|
||||
}
|
||||
let mut has_dirty = false;
|
||||
let mut dirty_badge: Option<char> = None;
|
||||
let mut has_staged = false;
|
||||
for child in nodes[idx].children.clone() {
|
||||
if let Some(child_node) = nodes.get(child) {
|
||||
match child_node.git.cleanliness {
|
||||
'●' => {
|
||||
has_dirty = true;
|
||||
let candidate = child_node.git.badge.unwrap_or('M');
|
||||
dirty_badge = Some(match (dirty_badge, candidate) {
|
||||
(Some('D'), _) | (_, 'D') => 'D',
|
||||
(Some('U'), _) | (_, 'U') => 'U',
|
||||
(Some(existing), _) => existing,
|
||||
(None, new_badge) => new_badge,
|
||||
});
|
||||
}
|
||||
'○' => {
|
||||
has_staged = true;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nodes[idx].git = if has_dirty {
|
||||
GitDecoration::dirty(dirty_badge)
|
||||
} else if has_staged {
|
||||
GitDecoration::staged(None)
|
||||
} else {
|
||||
GitDecoration::clean()
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn collect_git_status(root: &Path) -> Result<HashMap<PathBuf, GitDecoration>> {
|
||||
if !root.join(".git").exists() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
|
||||
let output = Command::new("git")
|
||||
.arg("-C")
|
||||
.arg(root)
|
||||
.arg("status")
|
||||
.arg("--porcelain")
|
||||
.output()
|
||||
.with_context(|| format!("Failed to run git status in {}", root.display()))?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Ok(HashMap::new());
|
||||
}
|
||||
|
||||
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||
let mut map = HashMap::new();
|
||||
|
||||
for line in stdout.lines() {
|
||||
if line.len() < 3 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut chars = line.chars();
|
||||
let x = chars.next().unwrap_or(' ');
|
||||
let y = chars.next().unwrap_or(' ');
|
||||
if x == '!' || y == '!' {
|
||||
// ignored entry
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut path_part = line[3..].trim();
|
||||
if let Some(idx) = path_part.rfind(" -> ") {
|
||||
path_part = &path_part[idx + 4..];
|
||||
}
|
||||
|
||||
let path = PathBuf::from(path_part);
|
||||
|
||||
if let Some(decoration) = decode_git_status(x, y) {
|
||||
map.insert(path, decoration);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(map)
|
||||
}
|
||||
|
||||
fn current_git_branch(root: &Path) -> Result<Option<String>> {
|
||||
if !root.join(".git").exists() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let output = Command::new("git")
|
||||
.arg("-C")
|
||||
.arg(root)
|
||||
.arg("rev-parse")
|
||||
.arg("--abbrev-ref")
|
||||
.arg("HEAD")
|
||||
.output()
|
||||
.with_context(|| format!("Failed to query git branch in {}", root.display()))?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let branch = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
if branch.is_empty() {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(branch))
|
||||
}
|
||||
}
|
||||
|
||||
fn decode_git_status(x: char, y: char) -> Option<GitDecoration> {
|
||||
if x == ' ' && y == ' ' {
|
||||
return Some(GitDecoration::clean());
|
||||
}
|
||||
|
||||
if x == '?' && y == '?' {
|
||||
return Some(GitDecoration::dirty(Some('A')));
|
||||
}
|
||||
|
||||
let badge = match (x, y) {
|
||||
('M', _) | (_, 'M') => Some('M'),
|
||||
('A', _) | (_, 'A') => Some('A'),
|
||||
('D', _) | (_, 'D') => Some('D'),
|
||||
('R', _) | (_, 'R') => Some('R'),
|
||||
('C', _) | (_, 'C') => Some('A'),
|
||||
('U', _) | (_, 'U') => Some('U'),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
if y != ' ' {
|
||||
Some(GitDecoration::dirty(badge))
|
||||
} else if x != ' ' {
|
||||
Some(GitDecoration::staged(badge))
|
||||
} else {
|
||||
Some(GitDecoration::clean())
|
||||
}
|
||||
}
|
||||
307
crates/owlen-tui/src/state/keymap.rs
Normal file
307
crates/owlen-tui/src/state/keymap.rs
Normal file
@@ -0,0 +1,307 @@
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
|
||||
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
||||
use log::warn;
|
||||
use owlen_core::{config::default_config_path, ui::InputMode};
|
||||
use serde::Deserialize;
|
||||
|
||||
use crate::commands::registry::{AppCommand, CommandRegistry};
|
||||
|
||||
const DEFAULT_KEYMAP: &str = include_str!("../../keymap.toml");
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Keymap {
|
||||
bindings: HashMap<(InputMode, KeyPattern), AppCommand>,
|
||||
}
|
||||
|
||||
impl Keymap {
|
||||
pub fn load(custom_path: Option<&str>, registry: &CommandRegistry) -> Self {
|
||||
let mut content = None;
|
||||
|
||||
if let Some(path) = custom_path.and_then(expand_path) {
|
||||
if let Ok(text) = fs::read_to_string(&path) {
|
||||
content = Some(text);
|
||||
} else {
|
||||
warn!(
|
||||
"Failed to read keymap from {}. Falling back to defaults.",
|
||||
path.display()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if content.is_none() {
|
||||
let default_path = default_config_keymap_path();
|
||||
if let Some(path) = default_path {
|
||||
if let Ok(text) = fs::read_to_string(&path) {
|
||||
content = Some(text);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let data = content.unwrap_or_else(|| DEFAULT_KEYMAP.to_string());
|
||||
let parsed: KeymapConfig = toml::from_str(&data).unwrap_or_else(|err| {
|
||||
warn!("Failed to parse keymap: {err}. Using built-in defaults.");
|
||||
toml::from_str(DEFAULT_KEYMAP).expect("embedded keymap should parse successfully")
|
||||
});
|
||||
|
||||
let mut bindings = HashMap::new();
|
||||
|
||||
for entry in parsed.bindings {
|
||||
let mode = match parse_mode(&entry.mode) {
|
||||
Some(mode) => mode,
|
||||
None => {
|
||||
warn!("Unknown input mode '{}' in keymap binding", entry.mode);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let command = match registry.resolve(&entry.command) {
|
||||
Some(cmd) => cmd,
|
||||
None => {
|
||||
warn!("Unknown command '{}' in keymap binding", entry.command);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
for key in entry.keys.into_iter() {
|
||||
match KeyPattern::from_str(&key) {
|
||||
Some(pattern) => {
|
||||
bindings.insert((mode, pattern), command);
|
||||
}
|
||||
None => warn!(
|
||||
"Unrecognised key specification '{}' for mode {}",
|
||||
key, entry.mode
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self { bindings }
|
||||
}
|
||||
|
||||
pub fn resolve(&self, mode: InputMode, event: &KeyEvent) -> Option<AppCommand> {
|
||||
let pattern = KeyPattern::from_event(event)?;
|
||||
self.bindings.get(&(mode, pattern)).copied()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct KeymapConfig {
|
||||
#[serde(default, rename = "binding")]
|
||||
bindings: Vec<KeyBindingConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct KeyBindingConfig {
|
||||
mode: String,
|
||||
command: String,
|
||||
keys: KeyList,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
enum KeyList {
|
||||
Single(String),
|
||||
Multiple(Vec<String>),
|
||||
}
|
||||
|
||||
impl KeyList {
|
||||
fn into_iter(self) -> Vec<String> {
|
||||
match self {
|
||||
KeyList::Single(key) => vec![key],
|
||||
KeyList::Multiple(keys) => keys,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
struct KeyPattern {
|
||||
code: KeyCodeKind,
|
||||
modifiers: KeyModifiers,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
enum KeyCodeKind {
|
||||
Char(char),
|
||||
Enter,
|
||||
Tab,
|
||||
BackTab,
|
||||
Backspace,
|
||||
Esc,
|
||||
Up,
|
||||
Down,
|
||||
Left,
|
||||
Right,
|
||||
PageUp,
|
||||
PageDown,
|
||||
Home,
|
||||
End,
|
||||
F(u8),
|
||||
}
|
||||
|
||||
impl KeyPattern {
|
||||
fn from_event(event: &KeyEvent) -> Option<Self> {
|
||||
let code = match event.code {
|
||||
KeyCode::Char(c) => KeyCodeKind::Char(c),
|
||||
KeyCode::Enter => KeyCodeKind::Enter,
|
||||
KeyCode::Tab => KeyCodeKind::Tab,
|
||||
KeyCode::BackTab => KeyCodeKind::BackTab,
|
||||
KeyCode::Backspace => KeyCodeKind::Backspace,
|
||||
KeyCode::Esc => KeyCodeKind::Esc,
|
||||
KeyCode::Up => KeyCodeKind::Up,
|
||||
KeyCode::Down => KeyCodeKind::Down,
|
||||
KeyCode::Left => KeyCodeKind::Left,
|
||||
KeyCode::Right => KeyCodeKind::Right,
|
||||
KeyCode::PageUp => KeyCodeKind::PageUp,
|
||||
KeyCode::PageDown => KeyCodeKind::PageDown,
|
||||
KeyCode::Home => KeyCodeKind::Home,
|
||||
KeyCode::End => KeyCodeKind::End,
|
||||
KeyCode::F(n) => KeyCodeKind::F(n),
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
Some(Self {
|
||||
code,
|
||||
modifiers: normalize_modifiers(event.modifiers),
|
||||
})
|
||||
}
|
||||
|
||||
fn from_str(spec: &str) -> Option<Self> {
|
||||
let tokens: Vec<&str> = spec
|
||||
.split('+')
|
||||
.map(|token| token.trim())
|
||||
.filter(|token| !token.is_empty())
|
||||
.collect();
|
||||
|
||||
if tokens.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut modifiers = KeyModifiers::empty();
|
||||
let key_token = tokens.last().copied().unwrap();
|
||||
|
||||
for token in tokens[..tokens.len().saturating_sub(1)].iter() {
|
||||
match token.to_ascii_lowercase().as_str() {
|
||||
"ctrl" | "control" => modifiers.insert(KeyModifiers::CONTROL),
|
||||
"alt" | "option" => modifiers.insert(KeyModifiers::ALT),
|
||||
"shift" => modifiers.insert(KeyModifiers::SHIFT),
|
||||
other => warn!("Unknown modifier '{other}' in key binding '{spec}'"),
|
||||
}
|
||||
}
|
||||
|
||||
let code = parse_key_token(key_token, &mut modifiers)?;
|
||||
|
||||
Some(Self {
|
||||
code,
|
||||
modifiers: normalize_modifiers(modifiers),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_key_token(token: &str, modifiers: &mut KeyModifiers) -> Option<KeyCodeKind> {
|
||||
let token_lower = token.to_ascii_lowercase();
|
||||
let code = match token_lower.as_str() {
|
||||
"enter" | "return" => KeyCodeKind::Enter,
|
||||
"tab" => {
|
||||
if modifiers.contains(KeyModifiers::SHIFT) {
|
||||
modifiers.remove(KeyModifiers::SHIFT);
|
||||
KeyCodeKind::BackTab
|
||||
} else {
|
||||
KeyCodeKind::Tab
|
||||
}
|
||||
}
|
||||
"backtab" => KeyCodeKind::BackTab,
|
||||
"backspace" | "bs" => KeyCodeKind::Backspace,
|
||||
"esc" | "escape" => KeyCodeKind::Esc,
|
||||
"up" => KeyCodeKind::Up,
|
||||
"down" => KeyCodeKind::Down,
|
||||
"left" => KeyCodeKind::Left,
|
||||
"right" => KeyCodeKind::Right,
|
||||
"pageup" | "page_up" | "pgup" => KeyCodeKind::PageUp,
|
||||
"pagedown" | "page_down" | "pgdn" => KeyCodeKind::PageDown,
|
||||
"home" => KeyCodeKind::Home,
|
||||
"end" => KeyCodeKind::End,
|
||||
token if token.starts_with('f') && token.len() > 1 => {
|
||||
let num = token[1..].parse::<u8>().ok()?;
|
||||
KeyCodeKind::F(num)
|
||||
}
|
||||
"space" => KeyCodeKind::Char(' '),
|
||||
"semicolon" => KeyCodeKind::Char(';'),
|
||||
"slash" => KeyCodeKind::Char('/'),
|
||||
_ => {
|
||||
let chars: Vec<char> = token.chars().collect();
|
||||
if chars.len() == 1 {
|
||||
KeyCodeKind::Char(chars[0])
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
Some(code)
|
||||
}
|
||||
|
||||
fn parse_mode(mode: &str) -> Option<InputMode> {
|
||||
match mode.to_ascii_lowercase().as_str() {
|
||||
"normal" => Some(InputMode::Normal),
|
||||
"editing" => Some(InputMode::Editing),
|
||||
"command" => Some(InputMode::Command),
|
||||
"visual" => Some(InputMode::Visual),
|
||||
"provider_selection" | "provider" => Some(InputMode::ProviderSelection),
|
||||
"model_selection" | "model" => Some(InputMode::ModelSelection),
|
||||
"help" => Some(InputMode::Help),
|
||||
"session_browser" | "sessions" => Some(InputMode::SessionBrowser),
|
||||
"theme_browser" | "themes" => Some(InputMode::ThemeBrowser),
|
||||
"repo_search" | "search" => Some(InputMode::RepoSearch),
|
||||
"symbol_search" | "symbols" => Some(InputMode::SymbolSearch),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn default_config_keymap_path() -> Option<PathBuf> {
|
||||
let config_path = default_config_path();
|
||||
let dir = config_path.parent()?;
|
||||
Some(dir.join("keymap.toml"))
|
||||
}
|
||||
|
||||
fn expand_path(path: &str) -> Option<PathBuf> {
|
||||
if path.trim().is_empty() {
|
||||
return None;
|
||||
}
|
||||
let expanded = shellexpand::tilde(path);
|
||||
let candidate = Path::new(expanded.as_ref()).to_path_buf();
|
||||
Some(candidate)
|
||||
}
|
||||
|
||||
fn normalize_modifiers(modifiers: KeyModifiers) -> KeyModifiers {
|
||||
modifiers
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crossterm::event::{KeyCode, KeyModifiers};
|
||||
|
||||
#[test]
|
||||
fn resolve_binding_from_default_keymap() {
|
||||
let registry = CommandRegistry::new();
|
||||
assert!(registry.resolve("model.open_all").is_some());
|
||||
let parsed: KeymapConfig = toml::from_str(DEFAULT_KEYMAP).unwrap();
|
||||
assert!(!parsed.bindings.is_empty());
|
||||
let keymap = Keymap::load(None, ®istry);
|
||||
|
||||
let event = KeyEvent::new(KeyCode::Char('m'), KeyModifiers::NONE);
|
||||
assert!(
|
||||
!keymap.bindings.is_empty(),
|
||||
"expected default keymap to provide bindings"
|
||||
);
|
||||
assert_eq!(
|
||||
keymap.resolve(InputMode::Normal, &event),
|
||||
Some(AppCommand::OpenModelPicker(None))
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -6,5 +6,26 @@
|
||||
//! to test in isolation.
|
||||
|
||||
mod command_palette;
|
||||
mod debug_log;
|
||||
mod file_icons;
|
||||
mod file_tree;
|
||||
mod keymap;
|
||||
mod search;
|
||||
mod workspace;
|
||||
|
||||
pub use command_palette::{CommandPalette, ModelPaletteEntry};
|
||||
pub use command_palette::{CommandPalette, ModelPaletteEntry, PaletteGroup, PaletteSuggestion};
|
||||
pub use debug_log::{DebugLogEntry, DebugLogState, install_global_logger};
|
||||
pub use file_icons::{FileIconResolver, FileIconSet, IconDetection};
|
||||
pub use file_tree::{
|
||||
FileNode, FileTreeState, FilterMode as FileFilterMode, GitDecoration, VisibleFileEntry,
|
||||
};
|
||||
pub use keymap::Keymap;
|
||||
pub use search::{
|
||||
RepoSearchFile, RepoSearchMatch, RepoSearchMessage, RepoSearchRow, RepoSearchRowKind,
|
||||
RepoSearchState, SymbolEntry, SymbolKind, SymbolSearchMessage, SymbolSearchState,
|
||||
spawn_repo_search_task, spawn_symbol_search_task,
|
||||
};
|
||||
pub use workspace::{
|
||||
CodePane, CodeWorkspace, EditorTab, LayoutNode, PaneDirection, PaneId, PaneRestoreRequest,
|
||||
SplitAxis, WorkspaceSnapshot,
|
||||
};
|
||||
|
||||
1058
crates/owlen-tui/src/state/search.rs
Normal file
1058
crates/owlen-tui/src/state/search.rs
Normal file
File diff suppressed because it is too large
Load Diff
923
crates/owlen-tui/src/state/workspace.rs
Normal file
923
crates/owlen-tui/src/state/workspace.rs
Normal file
@@ -0,0 +1,923 @@
|
||||
use std::collections::HashMap;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use owlen_core::state::AutoScroll;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Cardinal direction used for navigating between panes or resizing splits.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum PaneDirection {
|
||||
Left,
|
||||
Right,
|
||||
Up,
|
||||
Down,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum ChildSide {
|
||||
First,
|
||||
Second,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
struct PathEntry {
|
||||
axis: SplitAxis,
|
||||
side: ChildSide,
|
||||
}
|
||||
|
||||
/// Identifier assigned to each pane rendered inside a tab.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct PaneId(u64);
|
||||
|
||||
impl PaneId {
|
||||
fn next(counter: &mut u64) -> Self {
|
||||
*counter += 1;
|
||||
PaneId(*counter)
|
||||
}
|
||||
|
||||
pub fn raw(self) -> u64 {
|
||||
self.0
|
||||
}
|
||||
|
||||
pub fn from_raw(raw: u64) -> Self {
|
||||
PaneId(raw)
|
||||
}
|
||||
}
|
||||
|
||||
/// Identifier used to refer to a tab within the workspace.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct TabId(u64);
|
||||
|
||||
impl TabId {
|
||||
fn next(counter: &mut u64) -> Self {
|
||||
*counter += 1;
|
||||
TabId(*counter)
|
||||
}
|
||||
|
||||
pub fn raw(self) -> u64 {
|
||||
self.0
|
||||
}
|
||||
|
||||
pub fn from_raw(raw: u64) -> Self {
|
||||
TabId(raw)
|
||||
}
|
||||
}
|
||||
|
||||
/// Direction used when splitting a pane.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum SplitAxis {
|
||||
/// Split horizontally to create a pane below the current one.
|
||||
Horizontal,
|
||||
/// Split vertically to create a pane to the right of the current one.
|
||||
Vertical,
|
||||
}
|
||||
|
||||
/// Layout node describing either a leaf pane or a container split.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum LayoutNode {
|
||||
Leaf(PaneId),
|
||||
Split {
|
||||
axis: SplitAxis,
|
||||
ratio: f32,
|
||||
first: Box<LayoutNode>,
|
||||
second: Box<LayoutNode>,
|
||||
},
|
||||
}
|
||||
|
||||
impl LayoutNode {
|
||||
pub fn ensure_ratio_bounds(&mut self) {
|
||||
match self {
|
||||
LayoutNode::Split {
|
||||
ratio,
|
||||
first,
|
||||
second,
|
||||
..
|
||||
} => {
|
||||
*ratio = ratio.clamp(0.1, 0.9);
|
||||
first.ensure_ratio_bounds();
|
||||
second.ensure_ratio_bounds();
|
||||
}
|
||||
LayoutNode::Leaf(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn nudge_ratio(&mut self, delta: f32) {
|
||||
match self {
|
||||
LayoutNode::Split { ratio, .. } => {
|
||||
*ratio = (*ratio + delta).clamp(0.1, 0.9);
|
||||
}
|
||||
LayoutNode::Leaf(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn replace_leaf(&mut self, target: PaneId, replacement: LayoutNode) -> bool {
|
||||
match self {
|
||||
LayoutNode::Leaf(id) => {
|
||||
if *id == target {
|
||||
*self = replacement;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
LayoutNode::Split { first, second, .. } => {
|
||||
first.replace_leaf(target, replacement.clone())
|
||||
|| second.replace_leaf(target, replacement)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn iter_leaves<'a>(&'a self, panes: &'a HashMap<PaneId, CodePane>) -> Vec<&'a CodePane> {
|
||||
let mut collected = Vec::new();
|
||||
self.collect_leaves(panes, &mut collected);
|
||||
collected
|
||||
}
|
||||
|
||||
fn collect_leaves<'a>(
|
||||
&'a self,
|
||||
panes: &'a HashMap<PaneId, CodePane>,
|
||||
output: &mut Vec<&'a CodePane>,
|
||||
) {
|
||||
match self {
|
||||
LayoutNode::Leaf(id) => {
|
||||
if let Some(pane) = panes.get(id) {
|
||||
output.push(pane);
|
||||
}
|
||||
}
|
||||
LayoutNode::Split { first, second, .. } => {
|
||||
first.collect_leaves(panes, output);
|
||||
second.collect_leaves(panes, output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn path_to(&self, target: PaneId) -> Option<Vec<PathEntry>> {
|
||||
let mut path = Vec::new();
|
||||
if self.path_to_inner(target, &mut path) {
|
||||
Some(path)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn path_to_inner(&self, target: PaneId, path: &mut Vec<PathEntry>) -> bool {
|
||||
match self {
|
||||
LayoutNode::Leaf(id) => *id == target,
|
||||
LayoutNode::Split {
|
||||
axis,
|
||||
first,
|
||||
second,
|
||||
..
|
||||
} => {
|
||||
path.push(PathEntry {
|
||||
axis: *axis,
|
||||
side: ChildSide::First,
|
||||
});
|
||||
if first.path_to_inner(target, path) {
|
||||
return true;
|
||||
}
|
||||
path.pop();
|
||||
path.push(PathEntry {
|
||||
axis: *axis,
|
||||
side: ChildSide::Second,
|
||||
});
|
||||
if second.path_to_inner(target, path) {
|
||||
return true;
|
||||
}
|
||||
path.pop();
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn subtree(&self, path: &[PathEntry]) -> Option<&LayoutNode> {
|
||||
let mut node = self;
|
||||
for entry in path {
|
||||
match node {
|
||||
LayoutNode::Split { first, second, .. } => {
|
||||
node = match entry.side {
|
||||
ChildSide::First => first.as_ref(),
|
||||
ChildSide::Second => second.as_ref(),
|
||||
};
|
||||
}
|
||||
LayoutNode::Leaf(_) => return None,
|
||||
}
|
||||
}
|
||||
Some(node)
|
||||
}
|
||||
|
||||
fn subtree_mut(&mut self, path: &[PathEntry]) -> Option<&mut LayoutNode> {
|
||||
let mut node = self;
|
||||
for entry in path {
|
||||
match node {
|
||||
LayoutNode::Split { first, second, .. } => {
|
||||
node = match entry.side {
|
||||
ChildSide::First => first.as_mut(),
|
||||
ChildSide::Second => second.as_mut(),
|
||||
};
|
||||
}
|
||||
LayoutNode::Leaf(_) => return None,
|
||||
}
|
||||
}
|
||||
Some(node)
|
||||
}
|
||||
|
||||
fn extreme_leaf(&self, prefer_second: bool) -> Option<PaneId> {
|
||||
match self {
|
||||
LayoutNode::Leaf(id) => Some(*id),
|
||||
LayoutNode::Split { first, second, .. } => {
|
||||
if prefer_second {
|
||||
second
|
||||
.extreme_leaf(prefer_second)
|
||||
.or_else(|| first.extreme_leaf(prefer_second))
|
||||
} else {
|
||||
first
|
||||
.extreme_leaf(prefer_second)
|
||||
.or_else(|| second.extreme_leaf(prefer_second))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Renderable pane that holds file contents and scroll state.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CodePane {
|
||||
pub id: PaneId,
|
||||
pub absolute_path: Option<PathBuf>,
|
||||
pub display_path: Option<String>,
|
||||
pub title: String,
|
||||
pub lines: Vec<String>,
|
||||
pub scroll: AutoScroll,
|
||||
pub viewport_height: usize,
|
||||
pub is_dirty: bool,
|
||||
pub is_staged: bool,
|
||||
}
|
||||
|
||||
impl CodePane {
|
||||
pub fn new(id: PaneId) -> Self {
|
||||
Self {
|
||||
id,
|
||||
absolute_path: None,
|
||||
display_path: None,
|
||||
title: "Untitled".to_string(),
|
||||
lines: Vec::new(),
|
||||
scroll: AutoScroll::default(),
|
||||
viewport_height: 0,
|
||||
is_dirty: false,
|
||||
is_staged: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_contents(
|
||||
&mut self,
|
||||
absolute_path: Option<PathBuf>,
|
||||
display_path: Option<String>,
|
||||
lines: Vec<String>,
|
||||
) {
|
||||
self.absolute_path = absolute_path;
|
||||
self.display_path = display_path;
|
||||
self.title = self
|
||||
.absolute_path
|
||||
.as_ref()
|
||||
.and_then(|path| path.file_name().map(|s| s.to_string_lossy().into_owned()))
|
||||
.or_else(|| self.display_path.clone())
|
||||
.unwrap_or_else(|| "Untitled".to_string());
|
||||
self.lines = lines;
|
||||
self.scroll = AutoScroll::default();
|
||||
self.scroll.content_len = self.lines.len();
|
||||
self.scroll.stick_to_bottom = false;
|
||||
self.scroll.scroll = 0;
|
||||
}
|
||||
|
||||
pub fn update_paths(&mut self, absolute_path: Option<PathBuf>, display_path: Option<String>) {
|
||||
self.absolute_path = absolute_path;
|
||||
self.display_path = display_path.clone();
|
||||
self.title = self
|
||||
.absolute_path
|
||||
.as_ref()
|
||||
.and_then(|path| path.file_name().map(|s| s.to_string_lossy().into_owned()))
|
||||
.or(display_path)
|
||||
.unwrap_or_else(|| "Untitled".to_string());
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
self.absolute_path = None;
|
||||
self.display_path = None;
|
||||
self.title = "Untitled".to_string();
|
||||
self.lines.clear();
|
||||
self.scroll = AutoScroll::default();
|
||||
self.viewport_height = 0;
|
||||
self.is_dirty = false;
|
||||
self.is_staged = false;
|
||||
}
|
||||
|
||||
pub fn set_viewport_height(&mut self, height: usize) {
|
||||
self.viewport_height = height;
|
||||
}
|
||||
|
||||
pub fn display_path(&self) -> Option<&str> {
|
||||
self.display_path.as_deref()
|
||||
}
|
||||
|
||||
pub fn absolute_path(&self) -> Option<&Path> {
|
||||
self.absolute_path.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
/// Individual tab containing a layout tree and panes.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EditorTab {
|
||||
pub id: TabId,
|
||||
pub title: String,
|
||||
pub root: LayoutNode,
|
||||
pub panes: HashMap<PaneId, CodePane>,
|
||||
pub active: PaneId,
|
||||
}
|
||||
|
||||
impl EditorTab {
|
||||
fn new(id: TabId, title: String, pane: CodePane) -> Self {
|
||||
let active = pane.id;
|
||||
let mut panes = HashMap::new();
|
||||
panes.insert(pane.id, pane);
|
||||
Self {
|
||||
id,
|
||||
title,
|
||||
root: LayoutNode::Leaf(active),
|
||||
panes,
|
||||
active,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn active_pane(&self) -> Option<&CodePane> {
|
||||
self.panes.get(&self.active)
|
||||
}
|
||||
|
||||
pub fn active_pane_mut(&mut self) -> Option<&mut CodePane> {
|
||||
self.panes.get_mut(&self.active)
|
||||
}
|
||||
|
||||
pub fn set_active(&mut self, pane: PaneId) {
|
||||
if self.panes.contains_key(&pane) {
|
||||
self.active = pane;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn update_title_from_active(&mut self) {
|
||||
if let Some(pane) = self.active_pane() {
|
||||
self.title = pane
|
||||
.absolute_path
|
||||
.as_ref()
|
||||
.and_then(|p| p.file_name().map(|s| s.to_string_lossy().into_owned()))
|
||||
.or_else(|| pane.display_path.clone())
|
||||
.unwrap_or_else(|| "Untitled".to_string());
|
||||
}
|
||||
}
|
||||
|
||||
fn active_path(&self) -> Option<Vec<PathEntry>> {
|
||||
self.root.path_to(self.active)
|
||||
}
|
||||
|
||||
pub fn move_focus(&mut self, direction: PaneDirection) -> bool {
|
||||
let path = match self.active_path() {
|
||||
Some(path) => path,
|
||||
None => return false,
|
||||
};
|
||||
let axis = match direction {
|
||||
PaneDirection::Left | PaneDirection::Right => SplitAxis::Vertical,
|
||||
PaneDirection::Up | PaneDirection::Down => SplitAxis::Horizontal,
|
||||
};
|
||||
|
||||
for (idx, entry) in path.iter().enumerate().rev() {
|
||||
if entry.axis != axis {
|
||||
continue;
|
||||
}
|
||||
|
||||
let (required_side, target_side, prefer_second) = match direction {
|
||||
PaneDirection::Left => (ChildSide::Second, ChildSide::First, true),
|
||||
PaneDirection::Right => (ChildSide::First, ChildSide::Second, false),
|
||||
PaneDirection::Up => (ChildSide::Second, ChildSide::First, true),
|
||||
PaneDirection::Down => (ChildSide::First, ChildSide::Second, false),
|
||||
};
|
||||
|
||||
if entry.side != required_side {
|
||||
continue;
|
||||
}
|
||||
|
||||
let parent_path = &path[..idx];
|
||||
let Some(parent) = self.root.subtree(parent_path) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if let LayoutNode::Split { first, second, .. } = parent {
|
||||
let target = match target_side {
|
||||
ChildSide::First => first.as_ref(),
|
||||
ChildSide::Second => second.as_ref(),
|
||||
};
|
||||
if let Some(pane_id) = target.extreme_leaf(prefer_second)
|
||||
&& self.panes.contains_key(&pane_id)
|
||||
{
|
||||
self.active = pane_id;
|
||||
self.update_title_from_active();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
pub fn resize_active_step(&mut self, direction: PaneDirection, amount: f32) -> Option<f32> {
|
||||
let path = self.active_path()?;
|
||||
|
||||
let axis = match direction {
|
||||
PaneDirection::Left | PaneDirection::Right => SplitAxis::Vertical,
|
||||
PaneDirection::Up | PaneDirection::Down => SplitAxis::Horizontal,
|
||||
};
|
||||
|
||||
let (idx, entry) = path
|
||||
.iter()
|
||||
.enumerate()
|
||||
.rev()
|
||||
.find(|(_, entry)| entry.axis == axis)?;
|
||||
|
||||
let parent_path = &path[..idx];
|
||||
let parent = self.root.subtree_mut(parent_path)?;
|
||||
|
||||
let LayoutNode::Split { ratio, .. } = parent else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let sign = match direction {
|
||||
PaneDirection::Left => {
|
||||
if entry.side == ChildSide::First {
|
||||
1.0
|
||||
} else {
|
||||
-1.0
|
||||
}
|
||||
}
|
||||
PaneDirection::Right => {
|
||||
if entry.side == ChildSide::First {
|
||||
-1.0
|
||||
} else {
|
||||
1.0
|
||||
}
|
||||
}
|
||||
PaneDirection::Up => {
|
||||
if entry.side == ChildSide::First {
|
||||
1.0
|
||||
} else {
|
||||
-1.0
|
||||
}
|
||||
}
|
||||
PaneDirection::Down => {
|
||||
if entry.side == ChildSide::First {
|
||||
-1.0
|
||||
} else {
|
||||
1.0
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let mut new_ratio = (*ratio + amount * sign).clamp(0.1, 0.9);
|
||||
if (new_ratio - *ratio).abs() < f32::EPSILON {
|
||||
return Some(self.active_share_from(entry.side, new_ratio));
|
||||
}
|
||||
*ratio = new_ratio;
|
||||
new_ratio = new_ratio.clamp(0.1, 0.9);
|
||||
Some(self.active_share_from(entry.side, new_ratio))
|
||||
}
|
||||
|
||||
pub fn snap_active_share(
|
||||
&mut self,
|
||||
direction: PaneDirection,
|
||||
desired_share: f32,
|
||||
) -> Option<f32> {
|
||||
let path = self.active_path()?;
|
||||
|
||||
let axis = match direction {
|
||||
PaneDirection::Left | PaneDirection::Right => SplitAxis::Vertical,
|
||||
PaneDirection::Up | PaneDirection::Down => SplitAxis::Horizontal,
|
||||
};
|
||||
|
||||
let (idx, entry) = path
|
||||
.iter()
|
||||
.enumerate()
|
||||
.rev()
|
||||
.find(|(_, entry)| entry.axis == axis)?;
|
||||
|
||||
let parent_path = &path[..idx];
|
||||
let parent = self.root.subtree_mut(parent_path)?;
|
||||
|
||||
let LayoutNode::Split { ratio, .. } = parent else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let mut target_ratio = match entry.side {
|
||||
ChildSide::First => desired_share,
|
||||
ChildSide::Second => 1.0 - desired_share,
|
||||
}
|
||||
.clamp(0.1, 0.9);
|
||||
|
||||
if (target_ratio - *ratio).abs() < f32::EPSILON {
|
||||
return Some(self.active_share_from(entry.side, target_ratio));
|
||||
}
|
||||
|
||||
*ratio = target_ratio;
|
||||
target_ratio = target_ratio.clamp(0.1, 0.9);
|
||||
Some(self.active_share_from(entry.side, target_ratio))
|
||||
}
|
||||
|
||||
pub fn active_share(&self) -> Option<f32> {
|
||||
let path = self.active_path()?;
|
||||
let (idx, entry) =
|
||||
path.iter().enumerate().rev().find(|(_, entry)| {
|
||||
matches!(entry.axis, SplitAxis::Horizontal | SplitAxis::Vertical)
|
||||
})?;
|
||||
let parent_path = &path[..idx];
|
||||
let parent = self.root.subtree(parent_path)?;
|
||||
if let LayoutNode::Split { ratio, .. } = parent {
|
||||
Some(self.active_share_from(entry.side, *ratio))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn active_share_from(&self, side: ChildSide, ratio: f32) -> f32 {
|
||||
match side {
|
||||
ChildSide::First => ratio,
|
||||
ChildSide::Second => 1.0 - ratio,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Top-level workspace managing tabs and panes for the code viewer.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CodeWorkspace {
|
||||
tabs: Vec<EditorTab>,
|
||||
active_tab: usize,
|
||||
next_tab_id: u64,
|
||||
next_pane_id: u64,
|
||||
}
|
||||
|
||||
const WORKSPACE_SNAPSHOT_VERSION: u32 = 1;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WorkspaceSnapshot {
|
||||
version: u32,
|
||||
active_tab: usize,
|
||||
next_tab_id: u64,
|
||||
next_pane_id: u64,
|
||||
tabs: Vec<TabSnapshot>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct TabSnapshot {
|
||||
id: u64,
|
||||
title: String,
|
||||
active: u64,
|
||||
root: LayoutNode,
|
||||
panes: Vec<PaneSnapshot>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct PaneSnapshot {
|
||||
id: u64,
|
||||
absolute_path: Option<String>,
|
||||
display_path: Option<String>,
|
||||
is_dirty: bool,
|
||||
is_staged: bool,
|
||||
scroll: ScrollSnapshot,
|
||||
viewport_height: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ScrollSnapshot {
|
||||
pub scroll: usize,
|
||||
pub stick_to_bottom: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PaneRestoreRequest {
|
||||
pub pane_id: PaneId,
|
||||
pub absolute_path: Option<PathBuf>,
|
||||
pub display_path: Option<String>,
|
||||
pub scroll: ScrollSnapshot,
|
||||
}
|
||||
|
||||
impl Default for CodeWorkspace {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl CodeWorkspace {
|
||||
pub fn new() -> Self {
|
||||
let mut next_tab_id = 0;
|
||||
let mut next_pane_id = 0;
|
||||
let pane_id = PaneId::next(&mut next_pane_id);
|
||||
let first_pane = CodePane::new(pane_id);
|
||||
let tab_id = TabId::next(&mut next_tab_id);
|
||||
let title = format!("Tab {}", tab_id.0);
|
||||
let first_tab = EditorTab::new(tab_id, title, first_pane);
|
||||
Self {
|
||||
tabs: vec![first_tab],
|
||||
active_tab: 0,
|
||||
next_tab_id,
|
||||
next_pane_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tabs(&self) -> &[EditorTab] {
|
||||
&self.tabs
|
||||
}
|
||||
|
||||
pub fn tabs_mut(&mut self) -> &mut [EditorTab] {
|
||||
&mut self.tabs
|
||||
}
|
||||
|
||||
pub fn active_tab_index(&self) -> usize {
|
||||
self.active_tab.min(self.tabs.len().saturating_sub(1))
|
||||
}
|
||||
|
||||
pub fn active_tab(&self) -> Option<&EditorTab> {
|
||||
self.tabs.get(self.active_tab_index())
|
||||
}
|
||||
|
||||
pub fn active_tab_mut(&mut self) -> Option<&mut EditorTab> {
|
||||
let idx = self.active_tab_index();
|
||||
self.tabs.get_mut(idx)
|
||||
}
|
||||
|
||||
pub fn active_pane(&self) -> Option<&CodePane> {
|
||||
self.active_tab().and_then(|tab| tab.active_pane())
|
||||
}
|
||||
|
||||
pub fn panes(&self) -> impl Iterator<Item = &CodePane> + '_ {
|
||||
self.tabs.iter().flat_map(|tab| tab.panes.values())
|
||||
}
|
||||
|
||||
pub fn active_pane_mut(&mut self) -> Option<&mut CodePane> {
|
||||
self.active_tab_mut().and_then(|tab| tab.active_pane_mut())
|
||||
}
|
||||
|
||||
pub fn set_active_tab(&mut self, index: usize) {
|
||||
if index < self.tabs.len() {
|
||||
self.active_tab = index;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ensure_tab(&mut self) {
|
||||
if self.tabs.is_empty() {
|
||||
let mut next_tab_id = self.next_tab_id;
|
||||
let mut next_pane_id = self.next_pane_id;
|
||||
let pane_id = PaneId::next(&mut next_pane_id);
|
||||
let pane = CodePane::new(pane_id);
|
||||
let tab_id = TabId::next(&mut next_tab_id);
|
||||
let title = format!("Tab {}", tab_id.0);
|
||||
let tab = EditorTab::new(tab_id, title, pane);
|
||||
self.tabs.push(tab);
|
||||
self.active_tab = 0;
|
||||
self.next_tab_id = next_tab_id;
|
||||
self.next_pane_id = next_pane_id;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_active_contents(
|
||||
&mut self,
|
||||
absolute: Option<PathBuf>,
|
||||
display: Option<String>,
|
||||
lines: Vec<String>,
|
||||
) {
|
||||
self.ensure_tab();
|
||||
if let Some(tab) = self.active_tab_mut() {
|
||||
if let Some(pane) = tab.active_pane_mut() {
|
||||
pane.set_contents(absolute, display, lines);
|
||||
}
|
||||
tab.update_title_from_active();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clear_active_pane(&mut self) {
|
||||
if let Some(tab) = self.active_tab_mut() {
|
||||
if let Some(pane) = tab.active_pane_mut() {
|
||||
pane.clear();
|
||||
}
|
||||
tab.update_title_from_active();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_active_viewport_height(&mut self, height: usize) {
|
||||
if let Some(pane) = self.active_pane_mut() {
|
||||
pane.set_viewport_height(height);
|
||||
}
|
||||
}
|
||||
|
||||
pub fn active_pane_id(&self) -> Option<PaneId> {
|
||||
self.active_tab().map(|tab| tab.active)
|
||||
}
|
||||
|
||||
pub fn split_active(&mut self, axis: SplitAxis) -> Option<PaneId> {
|
||||
self.ensure_tab();
|
||||
let active_id = self.active_tab()?.active;
|
||||
let new_pane_id = PaneId::next(&mut self.next_pane_id);
|
||||
let replacement = LayoutNode::Split {
|
||||
axis,
|
||||
ratio: 0.5,
|
||||
first: Box::new(LayoutNode::Leaf(active_id)),
|
||||
second: Box::new(LayoutNode::Leaf(new_pane_id)),
|
||||
};
|
||||
|
||||
self.active_tab_mut().and_then(|tab| {
|
||||
if tab.root.replace_leaf(active_id, replacement) {
|
||||
tab.panes.insert(new_pane_id, CodePane::new(new_pane_id));
|
||||
tab.active = new_pane_id;
|
||||
Some(new_pane_id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn open_new_tab(&mut self) -> PaneId {
|
||||
let pane_id = PaneId::next(&mut self.next_pane_id);
|
||||
let pane = CodePane::new(pane_id);
|
||||
let tab_id = TabId::next(&mut self.next_tab_id);
|
||||
let title = format!("Tab {}", tab_id.0);
|
||||
let tab = EditorTab::new(tab_id, title, pane);
|
||||
self.tabs.push(tab);
|
||||
self.active_tab = self.tabs.len().saturating_sub(1);
|
||||
pane_id
|
||||
}
|
||||
|
||||
pub fn snapshot(&self) -> WorkspaceSnapshot {
|
||||
let tabs = self
|
||||
.tabs
|
||||
.iter()
|
||||
.map(|tab| {
|
||||
let panes = tab
|
||||
.panes
|
||||
.values()
|
||||
.map(|pane| PaneSnapshot {
|
||||
id: pane.id.raw(),
|
||||
absolute_path: pane
|
||||
.absolute_path
|
||||
.as_ref()
|
||||
.map(|p| p.to_string_lossy().into_owned()),
|
||||
display_path: pane.display_path.clone(),
|
||||
is_dirty: pane.is_dirty,
|
||||
is_staged: pane.is_staged,
|
||||
scroll: ScrollSnapshot {
|
||||
scroll: pane.scroll.scroll,
|
||||
stick_to_bottom: pane.scroll.stick_to_bottom,
|
||||
},
|
||||
viewport_height: pane.viewport_height,
|
||||
})
|
||||
.collect();
|
||||
|
||||
TabSnapshot {
|
||||
id: tab.id.raw(),
|
||||
title: tab.title.clone(),
|
||||
active: tab.active.raw(),
|
||||
root: tab.root.clone(),
|
||||
panes,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
WorkspaceSnapshot {
|
||||
version: WORKSPACE_SNAPSHOT_VERSION,
|
||||
active_tab: self.active_tab_index(),
|
||||
next_tab_id: self.next_tab_id,
|
||||
next_pane_id: self.next_pane_id,
|
||||
tabs,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn apply_snapshot(&mut self, snapshot: WorkspaceSnapshot) -> Vec<PaneRestoreRequest> {
|
||||
if snapshot.version != WORKSPACE_SNAPSHOT_VERSION {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let mut restore_requests = Vec::new();
|
||||
let mut tabs = Vec::new();
|
||||
|
||||
for tab_snapshot in snapshot.tabs {
|
||||
let mut panes = HashMap::new();
|
||||
for pane_snapshot in tab_snapshot.panes {
|
||||
let pane_id = PaneId::from_raw(pane_snapshot.id);
|
||||
let mut pane = CodePane::new(pane_id);
|
||||
pane.absolute_path = pane_snapshot.absolute_path.as_ref().map(PathBuf::from);
|
||||
pane.display_path = pane_snapshot.display_path.clone();
|
||||
pane.is_dirty = pane_snapshot.is_dirty;
|
||||
pane.is_staged = pane_snapshot.is_staged;
|
||||
pane.scroll.scroll = pane_snapshot.scroll.scroll;
|
||||
pane.scroll.stick_to_bottom = pane_snapshot.scroll.stick_to_bottom;
|
||||
pane.viewport_height = pane_snapshot.viewport_height;
|
||||
pane.scroll.content_len = pane.lines.len();
|
||||
pane.title = pane
|
||||
.absolute_path
|
||||
.as_ref()
|
||||
.and_then(|p| p.file_name().map(|s| s.to_string_lossy().into_owned()))
|
||||
.or_else(|| pane.display_path.clone())
|
||||
.unwrap_or_else(|| "Untitled".to_string());
|
||||
panes.insert(pane_id, pane);
|
||||
|
||||
if pane_snapshot.absolute_path.is_some() {
|
||||
restore_requests.push(PaneRestoreRequest {
|
||||
pane_id,
|
||||
absolute_path: pane_snapshot.absolute_path.map(PathBuf::from),
|
||||
display_path: pane_snapshot.display_path.clone(),
|
||||
scroll: pane_snapshot.scroll.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if panes.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let tab_id = TabId::from_raw(tab_snapshot.id);
|
||||
let mut tab = EditorTab {
|
||||
id: tab_id,
|
||||
title: tab_snapshot.title,
|
||||
root: tab_snapshot.root,
|
||||
panes,
|
||||
active: PaneId::from_raw(tab_snapshot.active),
|
||||
};
|
||||
tab.update_title_from_active();
|
||||
tabs.push(tab);
|
||||
}
|
||||
|
||||
if tabs.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
self.tabs = tabs;
|
||||
self.active_tab = snapshot.active_tab.min(self.tabs.len().saturating_sub(1));
|
||||
self.next_tab_id = snapshot.next_tab_id;
|
||||
self.next_pane_id = snapshot.next_pane_id;
|
||||
|
||||
restore_requests
|
||||
}
|
||||
|
||||
pub fn move_focus(&mut self, direction: PaneDirection) -> bool {
|
||||
let active_index = self.active_tab_index();
|
||||
if let Some(tab) = self.tabs.get_mut(active_index) {
|
||||
tab.move_focus(direction)
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub fn resize_active_step(&mut self, direction: PaneDirection, amount: f32) -> Option<f32> {
|
||||
let active_index = self.active_tab_index();
|
||||
self.tabs
|
||||
.get_mut(active_index)
|
||||
.and_then(|tab| tab.resize_active_step(direction, amount))
|
||||
}
|
||||
|
||||
pub fn snap_active_share(
|
||||
&mut self,
|
||||
direction: PaneDirection,
|
||||
desired_share: f32,
|
||||
) -> Option<f32> {
|
||||
let active_index = self.active_tab_index();
|
||||
self.tabs
|
||||
.get_mut(active_index)
|
||||
.and_then(|tab| tab.snap_active_share(direction, desired_share))
|
||||
}
|
||||
|
||||
pub fn active_share(&self) -> Option<f32> {
|
||||
self.active_tab().and_then(|tab| tab.active_share())
|
||||
}
|
||||
|
||||
pub fn set_pane_contents(
|
||||
&mut self,
|
||||
pane_id: PaneId,
|
||||
absolute: Option<PathBuf>,
|
||||
display: Option<String>,
|
||||
lines: Vec<String>,
|
||||
) -> bool {
|
||||
for tab in &mut self.tabs {
|
||||
if let Some(pane) = tab.panes.get_mut(&pane_id) {
|
||||
pane.set_contents(absolute, display, lines);
|
||||
tab.update_title_from_active();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
pub fn restore_scroll(&mut self, pane_id: PaneId, snapshot: &ScrollSnapshot) -> bool {
|
||||
for tab in &mut self.tabs {
|
||||
if let Some(pane) = tab.panes.get_mut(&pane_id) {
|
||||
pane.scroll.scroll = snapshot.scroll;
|
||||
pane.scroll.stick_to_bottom = snapshot.stick_to_bottom;
|
||||
pane.scroll.content_len = pane.lines.len();
|
||||
return true;
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
}
|
||||
96
crates/owlen-tui/src/theme_util.rs
Normal file
96
crates/owlen-tui/src/theme_util.rs
Normal file
@@ -0,0 +1,96 @@
|
||||
macro_rules! adjust_fields {
|
||||
($theme:expr, $func:expr, $($field:ident),+ $(,)?) => {
|
||||
$(
|
||||
$theme.$field = $func($theme.$field);
|
||||
)+
|
||||
};
|
||||
}
|
||||
|
||||
use owlen_core::theme::Theme;
|
||||
use ratatui::style::Color;
|
||||
|
||||
/// Return a clone of `base` with contrast adjustments applied.
|
||||
/// Positive `steps` increase contrast, negative values decrease it.
|
||||
pub fn with_contrast(base: &Theme, steps: i8) -> Theme {
|
||||
if steps == 0 {
|
||||
return base.clone();
|
||||
}
|
||||
|
||||
let factor = (1.0 + (steps as f32) * 0.18).clamp(0.3, 2.0);
|
||||
let adjust = |color: Color| adjust_color(color, factor);
|
||||
|
||||
let mut theme = base.clone();
|
||||
adjust_fields!(
|
||||
theme,
|
||||
adjust,
|
||||
text,
|
||||
background,
|
||||
focused_panel_border,
|
||||
unfocused_panel_border,
|
||||
focus_beacon_fg,
|
||||
focus_beacon_bg,
|
||||
unfocused_beacon_fg,
|
||||
pane_header_active,
|
||||
pane_header_inactive,
|
||||
pane_hint_text,
|
||||
user_message_role,
|
||||
assistant_message_role,
|
||||
tool_output,
|
||||
thinking_panel_title,
|
||||
command_bar_background,
|
||||
status_background,
|
||||
mode_normal,
|
||||
mode_editing,
|
||||
mode_model_selection,
|
||||
mode_provider_selection,
|
||||
mode_help,
|
||||
mode_visual,
|
||||
mode_command,
|
||||
selection_bg,
|
||||
selection_fg,
|
||||
cursor,
|
||||
code_block_background,
|
||||
code_block_border,
|
||||
code_block_text,
|
||||
code_block_keyword,
|
||||
code_block_string,
|
||||
code_block_comment,
|
||||
placeholder,
|
||||
error,
|
||||
info,
|
||||
agent_thought,
|
||||
agent_action,
|
||||
agent_action_input,
|
||||
agent_observation,
|
||||
agent_final_answer,
|
||||
agent_badge_running_fg,
|
||||
agent_badge_running_bg,
|
||||
agent_badge_idle_fg,
|
||||
agent_badge_idle_bg,
|
||||
operating_chat_fg,
|
||||
operating_chat_bg,
|
||||
operating_code_fg,
|
||||
operating_code_bg
|
||||
);
|
||||
|
||||
theme
|
||||
}
|
||||
|
||||
fn adjust_color(color: Color, factor: f32) -> Color {
|
||||
match color {
|
||||
Color::Rgb(r, g, b) => {
|
||||
let adjust_component = |component: u8| -> u8 {
|
||||
let normalized = component as f32 / 255.0;
|
||||
let contrasted = ((normalized - 0.5) * factor + 0.5).clamp(0.0, 1.0);
|
||||
(contrasted * 255.0).round().clamp(0.0, 255.0) as u8
|
||||
};
|
||||
|
||||
Color::Rgb(
|
||||
adjust_component(r),
|
||||
adjust_component(g),
|
||||
adjust_component(b),
|
||||
)
|
||||
}
|
||||
_ => color,
|
||||
}
|
||||
}
|
||||
114
crates/owlen-tui/src/toast.rs
Normal file
114
crates/owlen-tui/src/toast.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
use std::collections::VecDeque;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
/// Severity level for toast notifications.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ToastLevel {
|
||||
Info,
|
||||
Success,
|
||||
Warning,
|
||||
Error,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Toast {
|
||||
pub message: String,
|
||||
pub level: ToastLevel,
|
||||
created: Instant,
|
||||
duration: Duration,
|
||||
}
|
||||
|
||||
impl Toast {
|
||||
fn new(message: String, level: ToastLevel, lifetime: Duration) -> Self {
|
||||
Self {
|
||||
message,
|
||||
level,
|
||||
created: Instant::now(),
|
||||
duration: lifetime,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_expired(&self, now: Instant) -> bool {
|
||||
now.duration_since(self.created) >= self.duration
|
||||
}
|
||||
}
|
||||
|
||||
/// Fixed-size toast queue with automatic expiration.
|
||||
#[derive(Debug)]
|
||||
pub struct ToastManager {
|
||||
items: VecDeque<Toast>,
|
||||
max_active: usize,
|
||||
lifetime: Duration,
|
||||
}
|
||||
|
||||
impl Default for ToastManager {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl ToastManager {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
items: VecDeque::new(),
|
||||
max_active: 3,
|
||||
lifetime: Duration::from_secs(3),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_lifetime(mut self, duration: Duration) -> Self {
|
||||
self.lifetime = duration;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn push(&mut self, message: impl Into<String>, level: ToastLevel) {
|
||||
let toast = Toast::new(message.into(), level, self.lifetime);
|
||||
self.items.push_front(toast);
|
||||
while self.items.len() > self.max_active {
|
||||
self.items.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn retain_active(&mut self) {
|
||||
let now = Instant::now();
|
||||
self.items.retain(|toast| !toast.is_expired(now));
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = &Toast> {
|
||||
self.items.iter()
|
||||
}
|
||||
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.items.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::thread::sleep;
|
||||
|
||||
#[test]
|
||||
fn manager_limits_active_toasts() {
|
||||
let mut manager = ToastManager::new();
|
||||
manager.push("first", ToastLevel::Info);
|
||||
manager.push("second", ToastLevel::Warning);
|
||||
manager.push("third", ToastLevel::Success);
|
||||
manager.push("fourth", ToastLevel::Error);
|
||||
|
||||
let collected: Vec<_> = manager.iter().map(|toast| toast.message.clone()).collect();
|
||||
assert_eq!(collected.len(), 3);
|
||||
assert_eq!(collected[0], "fourth");
|
||||
assert_eq!(collected[2], "second");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn manager_expires_toasts_after_lifetime() {
|
||||
let mut manager = ToastManager::new().with_lifetime(Duration::from_millis(1));
|
||||
manager.push("short lived", ToastLevel::Info);
|
||||
assert!(!manager.is_empty());
|
||||
sleep(Duration::from_millis(5));
|
||||
manager.retain_active();
|
||||
assert!(manager.is_empty());
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
3
crates/owlen-tui/src/widgets/mod.rs
Normal file
3
crates/owlen-tui/src/widgets/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
//! Reusable widgets composed specifically for the Owlen TUI.
|
||||
|
||||
pub mod model_picker;
|
||||
864
crates/owlen-tui/src/widgets/model_picker.rs
Normal file
864
crates/owlen-tui/src/widgets/model_picker.rs
Normal file
@@ -0,0 +1,864 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
use owlen_core::provider::{AnnotatedModelInfo, ProviderStatus, ProviderType};
|
||||
use owlen_core::types::ModelInfo;
|
||||
use ratatui::{
|
||||
Frame,
|
||||
layout::{Constraint, Direction, Layout, Rect},
|
||||
style::{Color, Modifier, Style},
|
||||
text::{Line, Span},
|
||||
widgets::{Block, Borders, Clear, List, ListItem, ListState, Paragraph},
|
||||
};
|
||||
use unicode_segmentation::UnicodeSegmentation;
|
||||
use unicode_width::UnicodeWidthStr;
|
||||
|
||||
use crate::chat_app::{
|
||||
ChatApp, HighlightMask, ModelAvailabilityState, ModelScope, ModelSearchInfo,
|
||||
ModelSelectorItemKind,
|
||||
};
|
||||
|
||||
/// Filtering modes for the model picker popup.
|
||||
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum FilterMode {
|
||||
#[default]
|
||||
All,
|
||||
LocalOnly,
|
||||
CloudOnly,
|
||||
Available,
|
||||
}
|
||||
|
||||
pub fn render_model_picker(frame: &mut Frame<'_>, app: &ChatApp) {
|
||||
let theme = app.theme();
|
||||
let area = frame.area();
|
||||
if area.width == 0 || area.height == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let selector_items = app.model_selector_items();
|
||||
if selector_items.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let search_query = app.model_search_query().trim().to_string();
|
||||
let search_active = !search_query.is_empty();
|
||||
|
||||
let max_width = area.width.min(90);
|
||||
let min_width = area.width.min(56);
|
||||
let width = area.width.min(max_width).max(min_width).max(1);
|
||||
|
||||
let visible_models = app.visible_model_count();
|
||||
let min_rows: usize = if search_active { 5 } else { 4 };
|
||||
let max_rows: usize = 12;
|
||||
let row_estimate = visible_models.max(min_rows).min(max_rows);
|
||||
let mut height = (row_estimate as u16) * 3 + 8;
|
||||
let min_height = area.height.clamp(8, 12);
|
||||
let max_height = area.height.min(32);
|
||||
height = height.clamp(min_height, max_height);
|
||||
|
||||
let x = area.x + (area.width.saturating_sub(width)) / 2;
|
||||
let mut y = area.y + (area.height.saturating_sub(height)) / 3;
|
||||
if y < area.y {
|
||||
y = area.y;
|
||||
}
|
||||
|
||||
let popup_area = Rect::new(x, y, width, height);
|
||||
frame.render_widget(Clear, popup_area);
|
||||
|
||||
let mut title_spans = vec![
|
||||
Span::styled(
|
||||
" Model Selector ",
|
||||
Style::default().fg(theme.info).add_modifier(Modifier::BOLD),
|
||||
),
|
||||
Span::styled(
|
||||
format!("· Provider: {}", app.selected_provider),
|
||||
Style::default()
|
||||
.fg(theme.placeholder)
|
||||
.add_modifier(Modifier::DIM),
|
||||
),
|
||||
];
|
||||
if app.model_filter_mode() != FilterMode::All {
|
||||
title_spans.push(Span::raw(" "));
|
||||
title_spans.push(filter_badge(app.model_filter_mode(), theme));
|
||||
}
|
||||
|
||||
let block = Block::default()
|
||||
.title(Line::from(title_spans))
|
||||
.borders(Borders::ALL)
|
||||
.border_style(Style::default().fg(theme.info))
|
||||
.style(Style::default().bg(theme.background).fg(theme.text));
|
||||
|
||||
let inner = block.inner(popup_area);
|
||||
frame.render_widget(block, popup_area);
|
||||
if inner.width == 0 || inner.height == 0 {
|
||||
return;
|
||||
}
|
||||
|
||||
let layout = Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([
|
||||
Constraint::Length(3),
|
||||
Constraint::Min(4),
|
||||
Constraint::Length(2),
|
||||
])
|
||||
.split(inner);
|
||||
|
||||
let matches = app.visible_model_count();
|
||||
let search_prefix = Style::default()
|
||||
.fg(theme.placeholder)
|
||||
.add_modifier(Modifier::DIM);
|
||||
let bracket_style = Style::default()
|
||||
.fg(theme.placeholder)
|
||||
.add_modifier(Modifier::DIM);
|
||||
let caret_style = if search_active {
|
||||
Style::default()
|
||||
.fg(theme.selection_fg)
|
||||
.add_modifier(Modifier::BOLD)
|
||||
} else {
|
||||
Style::default()
|
||||
.fg(theme.placeholder)
|
||||
.add_modifier(Modifier::DIM)
|
||||
};
|
||||
|
||||
let mut search_spans = Vec::new();
|
||||
search_spans.push(Span::styled("Search ▸ ", search_prefix));
|
||||
search_spans.push(Span::styled("[", bracket_style));
|
||||
search_spans.push(Span::styled(" ", bracket_style));
|
||||
|
||||
if search_active {
|
||||
search_spans.push(Span::styled(
|
||||
search_query.clone(),
|
||||
Style::default()
|
||||
.fg(theme.selection_fg)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
));
|
||||
} else {
|
||||
search_spans.push(Span::styled(
|
||||
"Type to search…",
|
||||
Style::default()
|
||||
.fg(theme.placeholder)
|
||||
.add_modifier(Modifier::DIM),
|
||||
));
|
||||
}
|
||||
|
||||
search_spans.push(Span::styled(" ", bracket_style));
|
||||
search_spans.push(Span::styled("▎", caret_style));
|
||||
search_spans.push(Span::styled(" ", bracket_style));
|
||||
search_spans.push(Span::styled("]", bracket_style));
|
||||
search_spans.push(Span::raw(" "));
|
||||
let suffix_label = if search_active { "match" } else { "model" };
|
||||
search_spans.push(Span::styled(
|
||||
format!(
|
||||
"({} {}{})",
|
||||
matches,
|
||||
suffix_label,
|
||||
if matches == 1 { "" } else { "s" }
|
||||
),
|
||||
Style::default().fg(theme.placeholder),
|
||||
));
|
||||
|
||||
let search_line = Line::from(search_spans);
|
||||
|
||||
let instruction_line = if search_active {
|
||||
Line::from(vec![
|
||||
Span::styled("Backspace", Style::default().fg(theme.placeholder)),
|
||||
Span::raw(": delete "),
|
||||
Span::styled("Ctrl+U", Style::default().fg(theme.placeholder)),
|
||||
Span::raw(": clear "),
|
||||
Span::styled("Enter", Style::default().fg(theme.placeholder)),
|
||||
Span::raw(": select "),
|
||||
Span::styled("Esc", Style::default().fg(theme.placeholder)),
|
||||
Span::raw(": close"),
|
||||
])
|
||||
} else {
|
||||
Line::from(vec![
|
||||
Span::styled("Enter", Style::default().fg(theme.placeholder)),
|
||||
Span::raw(": select "),
|
||||
Span::styled("Space", Style::default().fg(theme.placeholder)),
|
||||
Span::raw(": toggle provider "),
|
||||
Span::styled("Esc", Style::default().fg(theme.placeholder)),
|
||||
Span::raw(": close"),
|
||||
])
|
||||
};
|
||||
|
||||
let search_paragraph = Paragraph::new(vec![search_line, instruction_line])
|
||||
.style(Style::default().bg(theme.background).fg(theme.text));
|
||||
frame.render_widget(search_paragraph, layout[0]);
|
||||
|
||||
let highlight_style = Style::default()
|
||||
.fg(theme.selection_fg)
|
||||
.bg(theme.selection_bg)
|
||||
.add_modifier(Modifier::BOLD);
|
||||
|
||||
let highlight_symbol = " ";
|
||||
let highlight_width = UnicodeWidthStr::width(highlight_symbol);
|
||||
let max_line_width = layout[1]
|
||||
.width
|
||||
.saturating_sub(highlight_width as u16)
|
||||
.max(1) as usize;
|
||||
|
||||
let active_model_id = app.selected_model();
|
||||
let annotated = app.annotated_models();
|
||||
|
||||
let mut items: Vec<ListItem> = Vec::new();
|
||||
for item in selector_items.iter() {
|
||||
match item.kind() {
|
||||
ModelSelectorItemKind::Header {
|
||||
provider,
|
||||
expanded,
|
||||
status,
|
||||
provider_type,
|
||||
} => {
|
||||
let mut spans = Vec::new();
|
||||
spans.push(status_icon(*status, theme));
|
||||
spans.push(Span::raw(" "));
|
||||
let header_spans = render_highlighted_text(
|
||||
provider,
|
||||
if search_active {
|
||||
app.provider_search_highlight(provider)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
Style::default()
|
||||
.fg(theme.mode_command)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
highlight_style,
|
||||
);
|
||||
spans.extend(header_spans);
|
||||
spans.push(Span::raw(" "));
|
||||
spans.push(provider_type_badge(*provider_type, theme));
|
||||
spans.push(Span::raw(" "));
|
||||
spans.push(Span::styled(
|
||||
if *expanded { "▼" } else { "▶" },
|
||||
Style::default()
|
||||
.fg(theme.placeholder)
|
||||
.add_modifier(Modifier::DIM),
|
||||
));
|
||||
|
||||
let line = clip_line_to_width(Line::from(spans), max_line_width);
|
||||
items.push(ListItem::new(vec![line]).style(Style::default().bg(theme.background)));
|
||||
}
|
||||
ModelSelectorItemKind::Scope { label, status, .. } => {
|
||||
let (style, icon) = scope_status_style(*status, theme);
|
||||
let line = clip_line_to_width(
|
||||
Line::from(vec![
|
||||
Span::styled(icon, style),
|
||||
Span::raw(" "),
|
||||
Span::styled(label.clone(), style),
|
||||
]),
|
||||
max_line_width,
|
||||
);
|
||||
items.push(ListItem::new(vec![line]).style(Style::default().bg(theme.background)));
|
||||
}
|
||||
ModelSelectorItemKind::Model { model_index, .. } => {
|
||||
let mut lines: Vec<Line<'static>> = Vec::new();
|
||||
if let Some(model) = app.model_info_by_index(*model_index) {
|
||||
let badges = model_badge_icons(model);
|
||||
let detail = app.cached_model_detail(&model.id);
|
||||
let annotated_model = annotated.get(*model_index);
|
||||
let search_info = if search_active {
|
||||
app.model_search_info(*model_index)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let (title, metadata) = build_model_selector_lines(
|
||||
theme,
|
||||
model,
|
||||
annotated_model,
|
||||
&badges,
|
||||
detail,
|
||||
model.id == active_model_id,
|
||||
SearchRenderContext {
|
||||
info: search_info,
|
||||
highlight_style,
|
||||
},
|
||||
);
|
||||
lines.push(clip_line_to_width(title, max_line_width));
|
||||
if let Some(meta) = metadata {
|
||||
lines.push(clip_line_to_width(meta, max_line_width));
|
||||
}
|
||||
} else {
|
||||
lines.push(clip_line_to_width(
|
||||
Line::from(Span::styled(
|
||||
" <model unavailable>",
|
||||
Style::default().fg(theme.error),
|
||||
)),
|
||||
max_line_width,
|
||||
));
|
||||
}
|
||||
items.push(ListItem::new(lines).style(Style::default().bg(theme.background)));
|
||||
}
|
||||
ModelSelectorItemKind::Empty {
|
||||
message, status, ..
|
||||
} => {
|
||||
let (style, icon) = empty_status_style(*status, theme);
|
||||
let msg = message
|
||||
.as_ref()
|
||||
.map(|msg| msg.as_str())
|
||||
.unwrap_or("(no models configured)");
|
||||
let mut spans = vec![Span::styled(icon, style), Span::raw(" ")];
|
||||
spans.push(Span::styled(format!(" {}", msg), style));
|
||||
let line = clip_line_to_width(Line::from(spans), max_line_width);
|
||||
items.push(ListItem::new(vec![line]).style(Style::default().bg(theme.background)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let list = List::new(items)
|
||||
.highlight_style(
|
||||
Style::default()
|
||||
.bg(theme.selection_bg)
|
||||
.fg(theme.selection_fg)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
)
|
||||
.highlight_symbol(" ");
|
||||
|
||||
let mut state = ListState::default();
|
||||
state.select(app.selected_model_item());
|
||||
frame.render_stateful_widget(list, layout[1], &mut state);
|
||||
|
||||
let footer_text = if search_active {
|
||||
"Enter: select · Space: toggle provider · Backspace: delete · Ctrl+U: clear"
|
||||
} else {
|
||||
"Enter: select · Space: toggle provider · Type to search · Esc: cancel"
|
||||
};
|
||||
|
||||
let footer = Paragraph::new(Line::from(Span::styled(
|
||||
footer_text,
|
||||
Style::default().fg(theme.placeholder),
|
||||
)))
|
||||
.alignment(ratatui::layout::Alignment::Center)
|
||||
.style(Style::default().bg(theme.background).fg(theme.placeholder));
|
||||
frame.render_widget(footer, layout[2]);
|
||||
}
|
||||
|
||||
fn status_icon(status: ProviderStatus, theme: &owlen_core::theme::Theme) -> Span<'static> {
|
||||
let (symbol, color) = match status {
|
||||
ProviderStatus::Available => ("✓", theme.info),
|
||||
ProviderStatus::Unavailable => ("✗", theme.error),
|
||||
ProviderStatus::RequiresSetup => ("⚙", Color::Yellow),
|
||||
};
|
||||
Span::styled(
|
||||
symbol,
|
||||
Style::default().fg(color).add_modifier(Modifier::BOLD),
|
||||
)
|
||||
}
|
||||
|
||||
fn provider_type_badge(
|
||||
provider_type: ProviderType,
|
||||
theme: &owlen_core::theme::Theme,
|
||||
) -> Span<'static> {
|
||||
let (label, color) = match provider_type {
|
||||
ProviderType::Local => ("[Local]", theme.mode_normal),
|
||||
ProviderType::Cloud => ("[Cloud]", theme.mode_help),
|
||||
};
|
||||
Span::styled(
|
||||
label,
|
||||
Style::default().fg(color).add_modifier(Modifier::BOLD),
|
||||
)
|
||||
}
|
||||
|
||||
fn scope_status_style(
|
||||
status: ModelAvailabilityState,
|
||||
theme: &owlen_core::theme::Theme,
|
||||
) -> (Style, &'static str) {
|
||||
match status {
|
||||
ModelAvailabilityState::Available => (
|
||||
Style::default().fg(theme.info).add_modifier(Modifier::BOLD),
|
||||
"✓",
|
||||
),
|
||||
ModelAvailabilityState::Unavailable => (
|
||||
Style::default()
|
||||
.fg(theme.error)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
"✗",
|
||||
),
|
||||
ModelAvailabilityState::Unknown => (
|
||||
Style::default()
|
||||
.fg(Color::Yellow)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
"⚙",
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn empty_status_style(
|
||||
status: Option<ModelAvailabilityState>,
|
||||
theme: &owlen_core::theme::Theme,
|
||||
) -> (Style, &'static str) {
|
||||
match status.unwrap_or(ModelAvailabilityState::Unknown) {
|
||||
ModelAvailabilityState::Available => (
|
||||
Style::default()
|
||||
.fg(theme.placeholder)
|
||||
.add_modifier(Modifier::DIM),
|
||||
"•",
|
||||
),
|
||||
ModelAvailabilityState::Unavailable => (
|
||||
Style::default()
|
||||
.fg(theme.error)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
"✗",
|
||||
),
|
||||
ModelAvailabilityState::Unknown => (
|
||||
Style::default()
|
||||
.fg(Color::Yellow)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
"⚙",
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn filter_badge(mode: FilterMode, theme: &owlen_core::theme::Theme) -> Span<'static> {
|
||||
let label = match mode {
|
||||
FilterMode::All => return Span::raw(""),
|
||||
FilterMode::LocalOnly => "Local",
|
||||
FilterMode::CloudOnly => "Cloud",
|
||||
FilterMode::Available => "Available",
|
||||
};
|
||||
Span::styled(
|
||||
format!("[{label}]"),
|
||||
Style::default()
|
||||
.fg(theme.mode_provider_selection)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_highlighted_text(
|
||||
text: &str,
|
||||
highlight: Option<&HighlightMask>,
|
||||
normal_style: Style,
|
||||
highlight_style: Style,
|
||||
) -> Vec<Span<'static>> {
|
||||
if text.is_empty() {
|
||||
return Vec::new();
|
||||
}
|
||||
|
||||
let graphemes: Vec<&str> = UnicodeSegmentation::graphemes(text, true).collect();
|
||||
let mask = highlight.map(|mask| mask.bits()).unwrap_or(&[]);
|
||||
|
||||
let mut spans: Vec<Span<'static>> = Vec::new();
|
||||
let mut buffer = String::new();
|
||||
let mut current_highlight = false;
|
||||
|
||||
for (idx, grapheme) in graphemes.iter().enumerate() {
|
||||
let mark = mask.get(idx).copied().unwrap_or(false);
|
||||
if idx == 0 {
|
||||
current_highlight = mark;
|
||||
}
|
||||
if mark != current_highlight {
|
||||
if !buffer.is_empty() {
|
||||
let style = if current_highlight {
|
||||
highlight_style
|
||||
} else {
|
||||
normal_style
|
||||
};
|
||||
spans.push(Span::styled(buffer.clone(), style));
|
||||
buffer.clear();
|
||||
}
|
||||
current_highlight = mark;
|
||||
}
|
||||
buffer.push_str(grapheme);
|
||||
}
|
||||
|
||||
if !buffer.is_empty() {
|
||||
let style = if current_highlight {
|
||||
highlight_style
|
||||
} else {
|
||||
normal_style
|
||||
};
|
||||
spans.push(Span::styled(buffer, style));
|
||||
}
|
||||
|
||||
if spans.is_empty() {
|
||||
spans.push(Span::styled(text.to_string(), normal_style));
|
||||
}
|
||||
|
||||
spans
|
||||
}
|
||||
|
||||
struct SearchRenderContext<'a> {
|
||||
info: Option<&'a ModelSearchInfo>,
|
||||
highlight_style: Style,
|
||||
}
|
||||
|
||||
fn build_model_selector_lines<'a>(
|
||||
theme: &owlen_core::theme::Theme,
|
||||
model: &'a ModelInfo,
|
||||
annotated: Option<&'a AnnotatedModelInfo>,
|
||||
badges: &[&'static str],
|
||||
detail: Option<&'a owlen_core::model::DetailedModelInfo>,
|
||||
is_current: bool,
|
||||
search: SearchRenderContext<'a>,
|
||||
) -> (Line<'static>, Option<Line<'static>>) {
|
||||
let provider_type = annotated
|
||||
.map(|info| info.model.provider.provider_type)
|
||||
.unwrap_or_else(|| match ChatApp::model_scope_from_capabilities(model) {
|
||||
ModelScope::Cloud => ProviderType::Cloud,
|
||||
ModelScope::Local => ProviderType::Local,
|
||||
ModelScope::Other(_) => {
|
||||
if model.provider.to_ascii_lowercase().contains("cloud") {
|
||||
ProviderType::Cloud
|
||||
} else {
|
||||
ProviderType::Local
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let mut spans: Vec<Span<'static>> = Vec::new();
|
||||
spans.push(Span::raw(" "));
|
||||
spans.push(provider_type_badge(provider_type, theme));
|
||||
spans.push(Span::raw(" "));
|
||||
|
||||
let name_style = Style::default().fg(theme.text).add_modifier(Modifier::BOLD);
|
||||
let id_style = Style::default()
|
||||
.fg(theme.placeholder)
|
||||
.add_modifier(Modifier::DIM);
|
||||
|
||||
let name_trimmed = model.name.trim();
|
||||
if !name_trimmed.is_empty() {
|
||||
let name_spans = render_highlighted_text(
|
||||
name_trimmed,
|
||||
search.info.and_then(|info| info.name.as_ref()),
|
||||
name_style,
|
||||
search.highlight_style,
|
||||
);
|
||||
spans.extend(name_spans);
|
||||
|
||||
if !model.id.eq_ignore_ascii_case(name_trimmed) {
|
||||
spans.push(Span::raw(" "));
|
||||
spans.push(Span::styled("·", Style::default().fg(theme.placeholder)));
|
||||
spans.push(Span::raw(" "));
|
||||
let id_spans = render_highlighted_text(
|
||||
model.id.as_str(),
|
||||
search.info.and_then(|info| info.id.as_ref()),
|
||||
id_style,
|
||||
search.highlight_style,
|
||||
);
|
||||
spans.extend(id_spans);
|
||||
}
|
||||
} else {
|
||||
let id_spans = render_highlighted_text(
|
||||
model.id.as_str(),
|
||||
search.info.and_then(|info| info.id.as_ref()),
|
||||
name_style,
|
||||
search.highlight_style,
|
||||
);
|
||||
spans.extend(id_spans);
|
||||
}
|
||||
|
||||
if !badges.is_empty() {
|
||||
spans.push(Span::raw(" "));
|
||||
spans.push(Span::styled(
|
||||
badges.join(" "),
|
||||
Style::default().fg(theme.placeholder),
|
||||
));
|
||||
}
|
||||
|
||||
if is_current {
|
||||
spans.push(Span::raw(" "));
|
||||
spans.push(Span::styled(
|
||||
"✓",
|
||||
Style::default().fg(theme.info).add_modifier(Modifier::BOLD),
|
||||
));
|
||||
}
|
||||
|
||||
let mut meta_tags: Vec<String> = Vec::new();
|
||||
let mut seen_meta: HashSet<String> = HashSet::new();
|
||||
let mut push_meta = |value: String| {
|
||||
let trimmed = value.trim();
|
||||
if trimmed.is_empty() {
|
||||
return;
|
||||
}
|
||||
let key = trimmed.to_ascii_lowercase();
|
||||
if seen_meta.insert(key) {
|
||||
meta_tags.push(trimmed.to_string());
|
||||
}
|
||||
};
|
||||
|
||||
let scope = ChatApp::model_scope_from_capabilities(model);
|
||||
let scope_label = ChatApp::scope_display_name(&scope);
|
||||
if !scope_label.eq_ignore_ascii_case("unknown") {
|
||||
push_meta(scope_label.clone());
|
||||
}
|
||||
|
||||
if let Some(detail) = detail {
|
||||
if let Some(ctx) = detail.context_length {
|
||||
push_meta(format!("max tokens {}", ctx));
|
||||
} else if let Some(ctx) = model.context_window {
|
||||
push_meta(format!("max tokens {}", ctx));
|
||||
}
|
||||
|
||||
if let Some(parameters) = detail
|
||||
.parameter_size
|
||||
.as_ref()
|
||||
.or(detail.parameters.as_ref())
|
||||
&& !parameters.trim().is_empty()
|
||||
{
|
||||
push_meta(parameters.trim().to_string());
|
||||
}
|
||||
|
||||
if let Some(arch) = detail.architecture.as_deref() {
|
||||
let trimmed = arch.trim();
|
||||
if !trimmed.is_empty() {
|
||||
push_meta(format!("arch {}", trimmed));
|
||||
}
|
||||
} else if let Some(family) = detail.family.as_deref() {
|
||||
let trimmed = family.trim();
|
||||
if !trimmed.is_empty() {
|
||||
push_meta(format!("family {}", trimmed));
|
||||
}
|
||||
} else if !detail.families.is_empty() {
|
||||
let families = detail
|
||||
.families
|
||||
.iter()
|
||||
.map(|f| f.trim())
|
||||
.filter(|f| !f.is_empty())
|
||||
.take(2)
|
||||
.collect::<Vec<_>>()
|
||||
.join("/");
|
||||
if !families.is_empty() {
|
||||
push_meta(format!("family {}", families));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(embedding) = detail.embedding_length {
|
||||
push_meta(format!("embedding {}", embedding));
|
||||
}
|
||||
|
||||
if let Some(size) = detail.size {
|
||||
push_meta(format_short_size(size));
|
||||
}
|
||||
|
||||
if let Some(quant) = detail
|
||||
.quantization
|
||||
.as_ref()
|
||||
.filter(|q| !q.trim().is_empty())
|
||||
{
|
||||
push_meta(format!("quant {}", quant.trim()));
|
||||
}
|
||||
} else if let Some(ctx) = model.context_window {
|
||||
push_meta(format!("max tokens {}", ctx));
|
||||
}
|
||||
|
||||
let mut description_segment: Option<(String, Option<HighlightMask>)> = None;
|
||||
if let Some(desc) = model.description.as_deref() {
|
||||
let trimmed = desc.trim();
|
||||
if !trimmed.is_empty() {
|
||||
let (display, retained, truncated) = ellipsize(trimmed, 80);
|
||||
let highlight = search
|
||||
.info
|
||||
.and_then(|info| info.description.as_ref())
|
||||
.filter(|mask| mask.is_marked())
|
||||
.map(|mask| {
|
||||
if truncated {
|
||||
mask.truncated(retained)
|
||||
} else {
|
||||
mask.clone()
|
||||
}
|
||||
});
|
||||
description_segment = Some((display, highlight));
|
||||
}
|
||||
}
|
||||
|
||||
let metadata = if meta_tags.is_empty() && description_segment.is_none() {
|
||||
None
|
||||
} else {
|
||||
let meta_style = Style::default()
|
||||
.fg(theme.placeholder)
|
||||
.add_modifier(Modifier::DIM);
|
||||
let mut segments: Vec<Span<'static>> = Vec::new();
|
||||
segments.push(Span::styled(" ", meta_style));
|
||||
|
||||
let mut first = true;
|
||||
for tag in meta_tags {
|
||||
if !first {
|
||||
segments.push(Span::styled(" • ", meta_style));
|
||||
}
|
||||
segments.push(Span::styled(tag, meta_style));
|
||||
first = false;
|
||||
}
|
||||
|
||||
if let Some((text, highlight)) = description_segment {
|
||||
if !first {
|
||||
segments.push(Span::styled(" • ", meta_style));
|
||||
}
|
||||
if let Some(mask) = highlight.as_ref() {
|
||||
let desc_spans = render_highlighted_text(
|
||||
text.as_str(),
|
||||
Some(mask),
|
||||
meta_style,
|
||||
search.highlight_style,
|
||||
);
|
||||
segments.extend(desc_spans);
|
||||
} else {
|
||||
segments.push(Span::styled(text, meta_style));
|
||||
}
|
||||
}
|
||||
|
||||
Some(Line::from(segments))
|
||||
};
|
||||
|
||||
(Line::from(spans), metadata)
|
||||
}
|
||||
|
||||
fn clip_line_to_width(line: Line<'_>, max_width: usize) -> Line<'static> {
|
||||
if max_width == 0 {
|
||||
return Line::from(Vec::<Span<'static>>::new());
|
||||
}
|
||||
|
||||
let mut used = 0usize;
|
||||
let mut clipped: Vec<Span<'static>> = Vec::new();
|
||||
|
||||
for span in line.spans {
|
||||
if used >= max_width {
|
||||
break;
|
||||
}
|
||||
let text = span.content.to_string();
|
||||
let span_width = UnicodeWidthStr::width(text.as_str());
|
||||
if used + span_width <= max_width {
|
||||
if !text.is_empty() {
|
||||
clipped.push(Span::styled(text, span.style));
|
||||
}
|
||||
used += span_width;
|
||||
} else {
|
||||
let mut buf = String::new();
|
||||
for grapheme in span.content.as_ref().graphemes(true) {
|
||||
let g_width = UnicodeWidthStr::width(grapheme);
|
||||
if g_width == 0 {
|
||||
buf.push_str(grapheme);
|
||||
continue;
|
||||
}
|
||||
if used + g_width > max_width {
|
||||
break;
|
||||
}
|
||||
buf.push_str(grapheme);
|
||||
used += g_width;
|
||||
}
|
||||
if !buf.is_empty() {
|
||||
clipped.push(Span::styled(buf, span.style));
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Line::from(clipped)
|
||||
}
|
||||
|
||||
fn ellipsize(text: &str, max_graphemes: usize) -> (String, usize, bool) {
|
||||
let graphemes: Vec<&str> = UnicodeSegmentation::graphemes(text, true).collect();
|
||||
if graphemes.len() <= max_graphemes {
|
||||
return (text.to_string(), graphemes.len(), false);
|
||||
}
|
||||
|
||||
let keep = max_graphemes.saturating_sub(1).max(1);
|
||||
let mut truncated = String::new();
|
||||
for grapheme in graphemes.iter().take(keep) {
|
||||
truncated.push_str(grapheme);
|
||||
}
|
||||
truncated.push('…');
|
||||
(truncated, keep, true)
|
||||
}
|
||||
|
||||
fn model_badge_icons(model: &ModelInfo) -> Vec<&'static str> {
|
||||
let mut badges = Vec::new();
|
||||
|
||||
if model.supports_tools {
|
||||
badges.push("🔧");
|
||||
}
|
||||
|
||||
if model_has_feature(model, &["think", "reason"]) {
|
||||
badges.push("🧠");
|
||||
}
|
||||
|
||||
if model_has_feature(model, &["vision", "multimodal", "image"]) {
|
||||
badges.push("👁️");
|
||||
}
|
||||
|
||||
if model_has_feature(model, &["audio", "speech", "voice"]) {
|
||||
badges.push("🎧");
|
||||
}
|
||||
|
||||
badges
|
||||
}
|
||||
|
||||
fn model_has_feature(model: &ModelInfo, keywords: &[&str]) -> bool {
|
||||
let name_lower = model.name.to_ascii_lowercase();
|
||||
if keywords.iter().any(|kw| name_lower.contains(kw)) {
|
||||
return true;
|
||||
}
|
||||
|
||||
if let Some(description) = &model.description {
|
||||
let description_lower = description.to_ascii_lowercase();
|
||||
if keywords.iter().any(|kw| description_lower.contains(kw)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if model.capabilities.iter().any(|cap| {
|
||||
let lc = cap.to_ascii_lowercase();
|
||||
keywords.iter().any(|kw| lc.contains(kw))
|
||||
}) {
|
||||
return true;
|
||||
}
|
||||
|
||||
keywords
|
||||
.iter()
|
||||
.any(|kw| model.provider.to_ascii_lowercase().contains(kw))
|
||||
}
|
||||
|
||||
fn format_short_size(bytes: u64) -> String {
|
||||
if bytes >= 1_000_000_000 {
|
||||
format!("{:.1} GB", bytes as f64 / 1_000_000_000_f64)
|
||||
} else if bytes >= 1_000_000 {
|
||||
format!("{:.1} MB", bytes as f64 / 1_000_000_f64)
|
||||
} else if bytes >= 1_000 {
|
||||
format!("{:.1} KB", bytes as f64 / 1_000_f64)
|
||||
} else {
|
||||
format!("{} B", bytes)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use owlen_core::types::ModelInfo;
|
||||
|
||||
fn model_with(capabilities: Vec<&str>, description: Option<&str>) -> ModelInfo {
|
||||
ModelInfo {
|
||||
id: "model".into(),
|
||||
name: "model".into(),
|
||||
description: description.map(|s| s.to_string()),
|
||||
provider: "test".into(),
|
||||
context_window: None,
|
||||
capabilities: capabilities.into_iter().map(|s| s.to_string()).collect(),
|
||||
supports_tools: false,
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn model_badges_recognize_thinking_capability() {
|
||||
let model = model_with(vec!["think"], None);
|
||||
assert!(model_badge_icons(&model).contains(&"🧠"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn model_badges_detect_tool_support() {
|
||||
let mut model = model_with(vec![], None);
|
||||
model.supports_tools = true;
|
||||
let icons = model_badge_icons(&model);
|
||||
assert!(icons.contains(&"🔧"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn model_badges_detect_vision_capability() {
|
||||
let model = model_with(vec![], Some("Supports vision tasks"));
|
||||
let icons = model_badge_icons(&model);
|
||||
assert!(icons.contains(&"👁️"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn model_badges_detect_audio_capability() {
|
||||
let model = model_with(vec!["audio"], None);
|
||||
let icons = model_badge_icons(&model);
|
||||
assert!(icons.contains(&"🎧"));
|
||||
}
|
||||
}
|
||||
164
crates/owlen-tui/tests/agent_flow_ui.rs
Normal file
164
crates/owlen-tui/tests/agent_flow_ui.rs
Normal file
@@ -0,0 +1,164 @@
|
||||
use std::{any::Any, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
||||
use futures_util::stream;
|
||||
use owlen_core::{
|
||||
Config, Mode, Provider,
|
||||
config::McpMode,
|
||||
session::SessionController,
|
||||
storage::StorageManager,
|
||||
types::{ChatResponse, Message, Role, ToolCall},
|
||||
ui::{NoOpUiController, UiController},
|
||||
};
|
||||
use owlen_tui::ChatApp;
|
||||
use owlen_tui::app::UiRuntime;
|
||||
use owlen_tui::events::Event;
|
||||
use tempfile::tempdir;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
struct StubProvider;
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for StubProvider {
|
||||
fn name(&self) -> &str {
|
||||
"stub-provider"
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> owlen_core::Result<Vec<owlen_core::types::ModelInfo>> {
|
||||
Ok(vec![owlen_core::types::ModelInfo {
|
||||
id: "stub-model".into(),
|
||||
name: "Stub Model".into(),
|
||||
description: Some("Stub model for testing".into()),
|
||||
provider: self.name().into(),
|
||||
context_window: Some(4096),
|
||||
capabilities: vec!["chat".into()],
|
||||
supports_tools: true,
|
||||
}])
|
||||
}
|
||||
|
||||
async fn send_prompt(
|
||||
&self,
|
||||
_request: owlen_core::types::ChatRequest,
|
||||
) -> owlen_core::Result<ChatResponse> {
|
||||
Ok(ChatResponse {
|
||||
message: Message::assistant("stub response".to_string()),
|
||||
usage: None,
|
||||
is_streaming: false,
|
||||
is_final: true,
|
||||
})
|
||||
}
|
||||
|
||||
async fn stream_prompt(
|
||||
&self,
|
||||
_request: owlen_core::types::ChatRequest,
|
||||
) -> owlen_core::Result<owlen_core::ChatStream> {
|
||||
Ok(Box::pin(stream::empty()))
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> owlen_core::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &(dyn Any + Send + Sync) {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn denied_consent_appends_apology_message() {
|
||||
let temp_dir = tempdir().expect("temp dir");
|
||||
let storage = Arc::new(
|
||||
StorageManager::with_database_path(temp_dir.path().join("owlen-tui-tests.db"))
|
||||
.await
|
||||
.expect("storage"),
|
||||
);
|
||||
|
||||
let mut config = Config::default();
|
||||
config.privacy.encrypt_local_data = false;
|
||||
config.general.default_model = Some("stub-model".into());
|
||||
config.mcp.mode = McpMode::LocalOnly;
|
||||
config
|
||||
.refresh_mcp_servers(None)
|
||||
.expect("refresh MCP servers");
|
||||
|
||||
let provider: Arc<dyn Provider> = Arc::new(StubProvider);
|
||||
let ui: Arc<dyn UiController> = Arc::new(NoOpUiController);
|
||||
let (event_tx, controller_event_rx) = mpsc::unbounded_channel();
|
||||
|
||||
// Pre-populate a pending consent request before handing the controller to the TUI.
|
||||
let mut session = SessionController::new(
|
||||
Arc::clone(&provider),
|
||||
config,
|
||||
Arc::clone(&storage),
|
||||
Arc::clone(&ui),
|
||||
true,
|
||||
Some(event_tx.clone()),
|
||||
)
|
||||
.await
|
||||
.expect("session controller");
|
||||
|
||||
session
|
||||
.set_operating_mode(Mode::Code)
|
||||
.await
|
||||
.expect("code mode");
|
||||
|
||||
let tool_call = ToolCall {
|
||||
id: "call-1".to_string(),
|
||||
name: "resources/delete".to_string(),
|
||||
arguments: serde_json::json!({"path": "/tmp/example.txt"}),
|
||||
};
|
||||
|
||||
let message_id = session
|
||||
.conversation_mut()
|
||||
.push_assistant_message("Preparing to modify files.");
|
||||
session
|
||||
.conversation_mut()
|
||||
.set_tool_calls_on_message(message_id, vec![tool_call])
|
||||
.expect("tool calls");
|
||||
|
||||
let advertised_calls = session
|
||||
.check_streaming_tool_calls(message_id)
|
||||
.expect("queued consent");
|
||||
assert_eq!(advertised_calls.len(), 1);
|
||||
|
||||
let (mut app, mut session_rx) = ChatApp::new(session, controller_event_rx)
|
||||
.await
|
||||
.expect("chat app");
|
||||
// Session events are not used in this test.
|
||||
session_rx.close();
|
||||
|
||||
// Process the controller event emitted by check_streaming_tool_calls.
|
||||
UiRuntime::poll_controller_events(&mut app).expect("poll controller events");
|
||||
assert!(app.has_pending_consent());
|
||||
|
||||
let consent_state = app
|
||||
.consent_dialog()
|
||||
.expect("consent dialog should be visible")
|
||||
.clone();
|
||||
assert_eq!(consent_state.tool_name, "resources/delete");
|
||||
|
||||
// Simulate the user pressing "4" to deny consent.
|
||||
let deny_key = KeyEvent::new(KeyCode::Char('4'), KeyModifiers::NONE);
|
||||
UiRuntime::handle_ui_event(&mut app, Event::Key(deny_key))
|
||||
.await
|
||||
.expect("handle deny key");
|
||||
|
||||
assert!(!app.has_pending_consent());
|
||||
assert!(
|
||||
app.status_message()
|
||||
.to_lowercase()
|
||||
.contains("consent denied")
|
||||
);
|
||||
|
||||
let conversation = app.conversation();
|
||||
let last_message = conversation.messages.last().expect("last message");
|
||||
assert_eq!(last_message.role, Role::Assistant);
|
||||
assert!(
|
||||
last_message
|
||||
.content
|
||||
.to_lowercase()
|
||||
.contains("consent was denied"),
|
||||
"assistant should acknowledge the denied consent"
|
||||
);
|
||||
}
|
||||
216
crates/owlen-tui/tests/generation_tests.rs
Normal file
216
crates/owlen-tui/tests/generation_tests.rs
Normal file
@@ -0,0 +1,216 @@
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Result;
|
||||
use async_trait::async_trait;
|
||||
use futures_util::stream;
|
||||
use owlen_core::provider::{
|
||||
GenerateChunk, GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderMetadata,
|
||||
ProviderStatus, ProviderType,
|
||||
};
|
||||
use owlen_core::state::AppState;
|
||||
use owlen_tui::app::{self, App, MessageState, messages::AppMessage};
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::task::{JoinHandle, yield_now};
|
||||
use tokio::time::advance;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct StatusProvider {
|
||||
metadata: ProviderMetadata,
|
||||
status: Arc<Mutex<ProviderStatus>>,
|
||||
chunks: Arc<Vec<GenerateChunk>>,
|
||||
}
|
||||
|
||||
impl StatusProvider {
|
||||
fn new(status: ProviderStatus, chunks: Vec<GenerateChunk>) -> Self {
|
||||
Self {
|
||||
metadata: ProviderMetadata::new("stub", "Stub", ProviderType::Local, false),
|
||||
status: Arc::new(Mutex::new(status)),
|
||||
chunks: Arc::new(chunks),
|
||||
}
|
||||
}
|
||||
|
||||
fn set_status(&self, status: ProviderStatus) {
|
||||
*self.status.lock().unwrap() = status;
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ModelProvider for StatusProvider {
|
||||
fn metadata(&self) -> &ProviderMetadata {
|
||||
&self.metadata
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> Result<ProviderStatus, owlen_core::Error> {
|
||||
Ok(*self.status.lock().unwrap())
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>, owlen_core::Error> {
|
||||
Ok(vec![])
|
||||
}
|
||||
|
||||
async fn generate_stream(
|
||||
&self,
|
||||
_request: GenerateRequest,
|
||||
) -> Result<GenerateStream, owlen_core::Error> {
|
||||
let items = Arc::clone(&self.chunks);
|
||||
let stream_items = items.as_ref().clone();
|
||||
Ok(Box::pin(stream::iter(stream_items.into_iter().map(Ok))))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct RecordingState {
|
||||
started: bool,
|
||||
appended: bool,
|
||||
completed: bool,
|
||||
failed: bool,
|
||||
refreshed: bool,
|
||||
updated: bool,
|
||||
provider_status: Option<ProviderStatus>,
|
||||
}
|
||||
|
||||
impl MessageState for RecordingState {
|
||||
fn start_generation(
|
||||
&mut self,
|
||||
_request_id: Uuid,
|
||||
_provider_id: &str,
|
||||
_request: &GenerateRequest,
|
||||
) -> AppState {
|
||||
self.started = true;
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
fn append_chunk(&mut self, _request_id: Uuid, _chunk: &GenerateChunk) -> AppState {
|
||||
self.appended = true;
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
fn generation_complete(&mut self, _request_id: Uuid) -> AppState {
|
||||
self.completed = true;
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
fn generation_failed(&mut self, _request_id: Option<Uuid>, _message: &str) -> AppState {
|
||||
self.failed = true;
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
fn refresh_model_list(&mut self) -> AppState {
|
||||
self.refreshed = true;
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
fn update_model_list(&mut self) -> AppState {
|
||||
self.updated = true;
|
||||
AppState::Running
|
||||
}
|
||||
|
||||
fn update_provider_status(&mut self, _provider_id: &str, status: ProviderStatus) -> AppState {
|
||||
self.provider_status = Some(status);
|
||||
AppState::Running
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_and_abort_generation_manage_active_state() {
|
||||
let manager = Arc::new(owlen_core::provider::ProviderManager::default());
|
||||
let provider = StatusProvider::new(
|
||||
ProviderStatus::Available,
|
||||
vec![
|
||||
GenerateChunk::from_text("hello"),
|
||||
GenerateChunk::final_chunk(),
|
||||
],
|
||||
);
|
||||
manager.register_provider(Arc::new(provider.clone())).await;
|
||||
let mut app = App::new(Arc::clone(&manager));
|
||||
|
||||
let request_id = app
|
||||
.start_generation("stub", GenerateRequest::new("stub-model"))
|
||||
.expect("start generation");
|
||||
assert!(app.has_active_generation());
|
||||
assert_ne!(request_id, Uuid::nil());
|
||||
|
||||
app.abort_active_generation();
|
||||
assert!(!app.has_active_generation());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handle_message_dispatches_variants() {
|
||||
let manager = Arc::new(owlen_core::provider::ProviderManager::default());
|
||||
let mut app = App::new(Arc::clone(&manager));
|
||||
let mut state = RecordingState::default();
|
||||
let request_id = Uuid::new_v4();
|
||||
|
||||
let _ = app.handle_message(
|
||||
&mut state,
|
||||
AppMessage::GenerateStart {
|
||||
request_id,
|
||||
provider_id: "stub".into(),
|
||||
request: GenerateRequest::new("stub"),
|
||||
},
|
||||
);
|
||||
let _ = app.handle_message(
|
||||
&mut state,
|
||||
AppMessage::GenerateChunk {
|
||||
request_id,
|
||||
chunk: GenerateChunk::from_text("chunk"),
|
||||
},
|
||||
);
|
||||
let _ = app.handle_message(&mut state, AppMessage::GenerateComplete { request_id });
|
||||
let _ = app.handle_message(
|
||||
&mut state,
|
||||
AppMessage::GenerateError {
|
||||
request_id: Some(request_id),
|
||||
message: "error".into(),
|
||||
},
|
||||
);
|
||||
let _ = app.handle_message(&mut state, AppMessage::ModelsRefresh);
|
||||
let _ = app.handle_message(&mut state, AppMessage::ModelsUpdated);
|
||||
let _ = app.handle_message(
|
||||
&mut state,
|
||||
AppMessage::ProviderStatus {
|
||||
provider_id: "stub".into(),
|
||||
status: ProviderStatus::Available,
|
||||
},
|
||||
);
|
||||
|
||||
assert!(state.started);
|
||||
assert!(state.appended);
|
||||
assert!(state.completed);
|
||||
assert!(state.failed);
|
||||
assert!(state.refreshed);
|
||||
assert!(state.updated);
|
||||
assert!(matches!(
|
||||
state.provider_status,
|
||||
Some(ProviderStatus::Available)
|
||||
));
|
||||
}
|
||||
|
||||
#[tokio::test(start_paused = true)]
|
||||
async fn background_worker_emits_status_changes() {
|
||||
let manager = Arc::new(owlen_core::provider::ProviderManager::default());
|
||||
let provider = StatusProvider::new(
|
||||
ProviderStatus::Unavailable,
|
||||
vec![GenerateChunk::final_chunk()],
|
||||
);
|
||||
manager.register_provider(Arc::new(provider.clone())).await;
|
||||
|
||||
let (tx, mut rx) = mpsc::unbounded_channel();
|
||||
let worker: JoinHandle<()> = tokio::spawn(app::background_worker(Arc::clone(&manager), tx));
|
||||
|
||||
provider.set_status(ProviderStatus::Available);
|
||||
advance(Duration::from_secs(31)).await;
|
||||
yield_now().await;
|
||||
|
||||
if let Some(AppMessage::ProviderStatus { status, .. }) = rx.recv().await {
|
||||
assert!(matches!(status, ProviderStatus::Available));
|
||||
} else {
|
||||
panic!("expected provider status update");
|
||||
}
|
||||
|
||||
worker.abort();
|
||||
let _ = worker.await;
|
||||
yield_now().await;
|
||||
}
|
||||
97
crates/owlen-tui/tests/message_tests.rs
Normal file
97
crates/owlen-tui/tests/message_tests.rs
Normal file
@@ -0,0 +1,97 @@
|
||||
use crossterm::event::{KeyCode, KeyEvent, KeyEventKind, KeyEventState, KeyModifiers};
|
||||
use owlen_core::provider::{GenerateChunk, GenerateRequest, ProviderStatus};
|
||||
use owlen_tui::app::messages::AppMessage;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[test]
|
||||
fn message_variants_roundtrip_their_data() {
|
||||
let request = GenerateRequest::new("demo-model");
|
||||
let request_id = Uuid::new_v4();
|
||||
let key_event = KeyEvent {
|
||||
code: KeyCode::Char('a'),
|
||||
modifiers: KeyModifiers::CONTROL,
|
||||
kind: KeyEventKind::Press,
|
||||
state: KeyEventState::NONE,
|
||||
};
|
||||
|
||||
let messages = vec![
|
||||
AppMessage::KeyPress(key_event),
|
||||
AppMessage::Resize {
|
||||
width: 120,
|
||||
height: 40,
|
||||
},
|
||||
AppMessage::Tick,
|
||||
AppMessage::GenerateStart {
|
||||
request_id,
|
||||
provider_id: "mock".into(),
|
||||
request: request.clone(),
|
||||
},
|
||||
AppMessage::GenerateChunk {
|
||||
request_id,
|
||||
chunk: GenerateChunk::from_text("hi"),
|
||||
},
|
||||
AppMessage::GenerateComplete { request_id },
|
||||
AppMessage::GenerateError {
|
||||
request_id: Some(request_id),
|
||||
message: "oops".into(),
|
||||
},
|
||||
AppMessage::ModelsRefresh,
|
||||
AppMessage::ModelsUpdated,
|
||||
AppMessage::ProviderStatus {
|
||||
provider_id: "mock".into(),
|
||||
status: ProviderStatus::Available,
|
||||
},
|
||||
];
|
||||
|
||||
for message in messages {
|
||||
match message {
|
||||
AppMessage::KeyPress(event) => {
|
||||
assert_eq!(event.code, KeyCode::Char('a'));
|
||||
assert!(event.modifiers.contains(KeyModifiers::CONTROL));
|
||||
}
|
||||
AppMessage::Resize { width, height } => {
|
||||
assert_eq!(width, 120);
|
||||
assert_eq!(height, 40);
|
||||
}
|
||||
AppMessage::Tick => {}
|
||||
AppMessage::GenerateStart {
|
||||
request_id: id,
|
||||
provider_id,
|
||||
request,
|
||||
} => {
|
||||
assert_eq!(id, request_id);
|
||||
assert_eq!(provider_id, "mock");
|
||||
assert_eq!(request.model, "demo-model");
|
||||
}
|
||||
AppMessage::GenerateChunk {
|
||||
request_id: id,
|
||||
chunk,
|
||||
} => {
|
||||
assert_eq!(id, request_id);
|
||||
assert_eq!(chunk.text.as_deref(), Some("hi"));
|
||||
}
|
||||
AppMessage::GenerateComplete { request_id: id } => {
|
||||
assert_eq!(id, request_id);
|
||||
}
|
||||
AppMessage::GenerateError {
|
||||
request_id: Some(id),
|
||||
message,
|
||||
} => {
|
||||
assert_eq!(id, request_id);
|
||||
assert_eq!(message, "oops");
|
||||
}
|
||||
AppMessage::ModelsRefresh => {}
|
||||
AppMessage::ModelsUpdated => {}
|
||||
AppMessage::ProviderStatus {
|
||||
provider_id,
|
||||
status,
|
||||
} => {
|
||||
assert_eq!(provider_id, "mock");
|
||||
assert!(matches!(status, ProviderStatus::Available));
|
||||
}
|
||||
AppMessage::GenerateError {
|
||||
request_id: None, ..
|
||||
} => panic!("missing request id"),
|
||||
}
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user