refactor(all)!: clean out project for v2
23  .cargo/config.toml
@@ -1,23 +0,0 @@
[alias]
xtask = "run -p xtask --"

[target.x86_64-unknown-linux-musl]
linker = "x86_64-linux-gnu-gcc"
rustflags = ["-C", "target-feature=+crt-static", "-C", "link-arg=-lgcc"]

[target.aarch64-unknown-linux-gnu]
linker = "aarch64-linux-gnu-gcc"

[target.aarch64-unknown-linux-musl]
linker = "aarch64-linux-gnu-gcc"
rustflags = ["-C", "target-feature=+crt-static", "-C", "link-arg=-lgcc"]

[target.armv7-unknown-linux-gnueabihf]
linker = "arm-linux-gnueabihf-gcc"

[target.armv7-unknown-linux-musleabihf]
linker = "arm-linux-gnueabihf-gcc"
rustflags = ["-C", "target-feature=+crt-static", "-C", "link-arg=-lgcc"]

[target.x86_64-pc-windows-gnu]
linker = "x86_64-w64-mingw32-gcc"
61  .github/workflows/macos-check.yml (vendored)
@@ -1,61 +0,0 @@
name: macos-check

on:
  push:
    branches:
      - dev
  pull_request:
    branches:
      - dev

jobs:
  build:
    name: cargo check (macOS)
    runs-on: macos-latest
    steps:
      - name: Checkout sources
        uses: actions/checkout@v4

      - name: Install Rust toolchain
        uses: dtolnay/rust-toolchain@stable

      - name: Cache Cargo registry
        uses: actions/cache@v4
        with:
          path: |
            ~/.cargo/registry
            ~/.cargo/git
            target
          key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
          restore-keys: |
            ${{ runner.os }}-cargo-

      - name: Cargo check
        run: cargo check --workspace --all-features

  ollama_regression:
    name: ollama provider regression
    runs-on: ubuntu-latest
    steps:
      - name: Checkout sources
        uses: actions/checkout@v4

      - name: Install Rust toolchain
        uses: dtolnay/rust-toolchain@stable

      - name: Cache Cargo registry
        uses: actions/cache@v4
        with:
          path: |
            ~/.cargo/registry
            ~/.cargo/git
            target
          key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
          restore-keys: |
            ${{ runner.os }}-cargo-

      - name: Run Ollama integration tests
        run: cargo test -p owlen-core --test ollama_wiremock

      - name: Run streaming/tool flow tests
        run: cargo test -p owlen-core --test agent_tool_flow
107  .gitignore (vendored)
@@ -1,107 +0,0 @@
### Rust template
# Generated by Cargo
# will have compiled files and executables
debug/
target/
images/generated/
dev/
.agents/
.env
.env.*
!.env.example

# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Cargo.lock

# These are backup files generated by rustfmt
**/*.rs.bk

# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb

# RustRover
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### JetBrains template
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

.idea/
# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# AWS User-specific
.idea/**/aws.xml

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# SonarLint plugin
.idea/sonarlint/

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser
35  .pre-commit-config.yaml
@@ -1,35 +0,0 @@
# Pre-commit hooks configuration
# See https://pre-commit.com for more information

repos:
  # General file checks
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v5.0.0
    hooks:
      - id: trailing-whitespace
      - id: end-of-file-fixer
      - id: check-yaml
        args: ['--allow-multiple-documents']
      - id: check-toml
      - id: check-merge-conflict
      - id: check-added-large-files
        args: ['--maxkb=1000']
      - id: mixed-line-ending

  # Rust formatting
  - repo: https://github.com/doublify/pre-commit-rust
    rev: v1.0
    hooks:
      - id: fmt
        name: cargo fmt
        description: Format Rust code with rustfmt
      - id: cargo-check
        name: cargo check
        description: Check Rust code compilation
      - id: clippy
        name: cargo clippy
        description: Lint Rust code with clippy
        args: ['--all-features', '--', '-D', 'warnings']

# Optional: run on all files when config changes
default_install_hook_types: [pre-commit, pre-push]
197  .woodpecker.yml
@@ -1,197 +0,0 @@
---
kind: pipeline
name: pr-checks

when:
  event:
    - push
    - pull_request

steps:
  - name: fmt-clippy-test
    image: rust:1.83
    commands:
      - rustup component add rustfmt clippy
      - cargo fmt --all -- --check
      - cargo clippy --workspace --all-features -- -D warnings
      - cargo test --workspace --all-features

---
kind: pipeline
name: security-audit

when:
  event:
    - push
    - cron
  branch:
    - dev
  cron: weekly-security

steps:
  - name: cargo-audit
    image: rust:1.83
    commands:
      - cargo install cargo-audit --locked
      - cargo audit

---
kind: pipeline
name: release-tests

when:
  event: tag
  tag: v*

steps:
  - name: workspace-tests
    image: rust:1.83
    commands:
      - rustup component add llvm-tools-preview
      - cargo install cargo-llvm-cov --locked
      - cargo llvm-cov --workspace --all-features --summary-only
      - cargo llvm-cov --workspace --all-features --lcov --output-path coverage.lcov --no-run

---
kind: pipeline
name: release

when:
  event: tag
  tag: v*

variables:
  - &rust_image 'rust:1.83'

depends_on:
  - release-tests

matrix:
  include:
    # Linux
    - TARGET: x86_64-unknown-linux-gnu
      ARTIFACT: owlen-linux-x86_64-gnu
      PLATFORM: linux
      EXT: ""
    - TARGET: x86_64-unknown-linux-musl
      ARTIFACT: owlen-linux-x86_64-musl
      PLATFORM: linux
      EXT: ""
    - TARGET: aarch64-unknown-linux-gnu
      ARTIFACT: owlen-linux-aarch64-gnu
      PLATFORM: linux
      EXT: ""
    - TARGET: aarch64-unknown-linux-musl
      ARTIFACT: owlen-linux-aarch64-musl
      PLATFORM: linux
      EXT: ""
    - TARGET: armv7-unknown-linux-gnueabihf
      ARTIFACT: owlen-linux-armv7-gnu
      PLATFORM: linux
      EXT: ""
    - TARGET: armv7-unknown-linux-musleabihf
      ARTIFACT: owlen-linux-armv7-musl
      PLATFORM: linux
      EXT: ""
    # Windows
    - TARGET: x86_64-pc-windows-gnu
      ARTIFACT: owlen-windows-x86_64
      PLATFORM: windows
      EXT: ".exe"

steps:
  - name: build
    image: *rust_image
    commands:
      # Install cross-compilation tools
      - apt-get update
      - apt-get install -y musl-tools gcc-aarch64-linux-gnu g++-aarch64-linux-gnu gcc-arm-linux-gnueabihf g++-arm-linux-gnueabihf mingw-w64 zip

      # Verify cross-compilers are installed
      - which aarch64-linux-gnu-gcc || echo "aarch64-linux-gnu-gcc not found!"
      - which arm-linux-gnueabihf-gcc || echo "arm-linux-gnueabihf-gcc not found!"
      - which x86_64-w64-mingw32-gcc || echo "x86_64-w64-mingw32-gcc not found!"

      # Add rust target
      - rustup target add ${TARGET}

      # Set up cross-compilation environment variables and build
      - |
        case "${TARGET}" in
          aarch64-unknown-linux-gnu)
            export CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER=/usr/bin/aarch64-linux-gnu-gcc
            export CC_aarch64_unknown_linux_gnu=/usr/bin/aarch64-linux-gnu-gcc
            export CXX_aarch64_unknown_linux_gnu=/usr/bin/aarch64-linux-gnu-g++
            export AR_aarch64_unknown_linux_gnu=/usr/bin/aarch64-linux-gnu-ar
            ;;
          aarch64-unknown-linux-musl)
            export CARGO_TARGET_AARCH64_UNKNOWN_LINUX_MUSL_LINKER=/usr/bin/aarch64-linux-gnu-gcc
            export CC_aarch64_unknown_linux_musl=/usr/bin/aarch64-linux-gnu-gcc
            export CXX_aarch64_unknown_linux_musl=/usr/bin/aarch64-linux-gnu-g++
            export AR_aarch64_unknown_linux_musl=/usr/bin/aarch64-linux-gnu-ar
            ;;
          armv7-unknown-linux-gnueabihf)
            export CARGO_TARGET_ARMV7_UNKNOWN_LINUX_GNUEABIHF_LINKER=/usr/bin/arm-linux-gnueabihf-gcc
            export CC_armv7_unknown_linux_gnueabihf=/usr/bin/arm-linux-gnueabihf-gcc
            export CXX_armv7_unknown_linux_gnueabihf=/usr/bin/arm-linux-gnueabihf-g++
            export AR_armv7_unknown_linux_gnueabihf=/usr/bin/arm-linux-gnueabihf-ar
            ;;
          armv7-unknown-linux-musleabihf)
            export CARGO_TARGET_ARMV7_UNKNOWN_LINUX_MUSLEABIHF_LINKER=/usr/bin/arm-linux-gnueabihf-gcc
            export CC_armv7_unknown_linux_musleabihf=/usr/bin/arm-linux-gnueabihf-gcc
            export CXX_armv7_unknown_linux_musleabihf=/usr/bin/arm-linux-gnueabihf-g++
            export AR_armv7_unknown_linux_musleabihf=/usr/bin/arm-linux-gnueabihf-ar
            ;;
          x86_64-pc-windows-gnu)
            export CARGO_TARGET_X86_64_PC_WINDOWS_GNU_LINKER=/usr/bin/x86_64-w64-mingw32-gcc
            export CC_x86_64_pc_windows_gnu=/usr/bin/x86_64-w64-mingw32-gcc
            export CXX_x86_64_pc_windows_gnu=/usr/bin/x86_64-w64-mingw32-g++
            export AR_x86_64_pc_windows_gnu=/usr/bin/x86_64-w64-mingw32-ar
            ;;
        esac

        # Build the project
        cargo build --release --all-features --target ${TARGET}

  - name: package
    image: *rust_image
    commands:
      - apt-get update && apt-get install -y zip
      - mkdir -p dist
      - |
        if [ "${PLATFORM}" = "windows" ]; then
          cp target/${TARGET}/release/owlen.exe dist/owlen.exe
          cp target/${TARGET}/release/owlen-code.exe dist/owlen-code.exe
          cd dist
          zip -9 ${ARTIFACT}.zip owlen.exe owlen-code.exe
          cd ..
          mv dist/${ARTIFACT}.zip .
          sha256sum ${ARTIFACT}.zip > ${ARTIFACT}.zip.sha256
        else
          cp target/${TARGET}/release/owlen dist/owlen
          cp target/${TARGET}/release/owlen-code dist/owlen-code
          cd dist
          tar czf ${ARTIFACT}.tar.gz owlen owlen-code
          cd ..
          mv dist/${ARTIFACT}.tar.gz .
          sha256sum ${ARTIFACT}.tar.gz > ${ARTIFACT}.tar.gz.sha256
        fi

  - name: release-notes
    image: *rust_image
    commands:
      - scripts/release-notes.sh "${CI_COMMIT_TAG}" release-notes.md

  - name: release
    image: plugins/gitea-release
    settings:
      api_key:
        from_secret: gitea_token
      base_url: https://somegit.dev
      files:
        - ${ARTIFACT}.tar.gz
        - ${ARTIFACT}.tar.gz.sha256
        - ${ARTIFACT}.zip
        - ${ARTIFACT}.zip.sha256
      title: Release ${CI_COMMIT_TAG}
      note_file: release-notes.md
144  CHANGELOG.md
@@ -1,144 +0,0 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]

### Added
- Comprehensive documentation suite including guides for architecture, configuration, testing, and more.
- Emacs keymap profile alongside runtime `:keymap` switching between Vim and Emacs layouts.
- Rustdoc examples for core components like `Provider` and `SessionController`.
- Module-level documentation for `owlen-tui`.
- Provider integration tests (`crates/owlen-providers/tests`) covering registration, routing, and health status handling for the new `ProviderManager`.
- TUI message and generation tests that exercise the non-blocking event loop, background worker, and message dispatch.
- Ollama integration can now talk to Ollama Cloud when an API key is configured.
- Ollama provider will also read `OLLAMA_API_KEY` / `OLLAMA_CLOUD_API_KEY` environment variables when no key is stored in the config.
- `owlen config doctor`, `owlen config path`, and `owlen upgrade` CLI commands to automate migrations and surface manual update steps.
- Startup provider health check with actionable hints when Ollama or remote MCP servers are unavailable.
- `dev/check-windows.sh` helper script for on-demand Windows cross-checks.
- Global F1 keybinding for the in-app help overlay and a clearer status hint on launch.
- Automatic fallback to the new `ansi_basic` theme when the active terminal only advertises 16-color support.
- Offline provider shim that keeps the TUI usable while primary providers are unreachable and communicates recovery steps inline.
- `owlen cloud` subcommands (`setup`, `status`, `models`, `logout`) for managing Ollama Cloud credentials without hand-editing config files.
- Tabbed model selector that separates local and cloud providers, including cloud indicators in the UI.
- Footer status line includes provider connectivity/credential summaries (e.g., cloud auth failures, missing API keys).
- Secure credential vault integration for Ollama Cloud API keys when `privacy.encrypt_local_data = true`.
- Input panel respects a new `ui.input_max_rows` setting so long prompts expand predictably before scrolling kicks in.
- Adaptive TUI layout with responsive 80/120-column breakpoints, refreshed glass/neon theming, and animated focus rings for pane transitions.
- Configurable `ui.layers` and `ui.animations` settings to tune glass elevation, neon intensity, and opt-in micro-animations.
- Adaptive transcript compactor with configurable auto mode, CLI opt-out (`--no-auto-compress`), and `:compress` commands for manual runs and toggling.
- Command palette offers fuzzy `:model` filtering and `:provider` completions for fast switching.
- Inline guidance overlay adds a three-step onboarding tour, keymap-aware cheat sheets (F1 / `?`), and persists completion state via `ui.guidance`.
- Status surface renders a layered HUD with streaming/tool indicators, contextual gauges, and redesigned toast cards featuring icons, countdown timers, and a compact history log.
- Published a TUI UX & keybinding playbook documenting modal ergonomics, command metadata, theming tokens, and animation policy.
- Automated TUI regression snapshots now cover mode transitions, keymap variants, accessibility presets, and multiple terminal breakpoints.
- Cloud usage tracker persists hourly/weekly token totals, adds a `:limits` command, shows live header badges, and raises toast warnings at 80 %/95 % of the configured quotas.
- Message rendering caches wrapped lines and throttles streaming redraws to keep the TUI responsive on long sessions.
- Model picker badges now inspect provider capabilities so vision/audio/thinking models surface the correct icons even when descriptions are sparse.
- Chat history honors `ui.scrollback_lines`, trimming older rows to keep the TUI responsive and surfacing a "↓ New messages" badge whenever updates land off-screen.

### Changed
- The main `README.md` has been updated to be more concise and link to the new documentation.
- Default configuration now pre-populates both `providers.ollama` and `providers.ollama-cloud` entries so switching between local and cloud backends is a single setting change.
- `McpMode` support was restored with explicit validation; `remote_only`, `remote_preferred`, and `local_only` now behave predictably.
- Configuration loading performs structural validation and fails fast on missing default providers or invalid MCP definitions.
- Ollama provider error handling now distinguishes timeouts, missing models, and authentication failures.
- The `web_search` tool now proxies through Ollama Cloud’s `/api/web_search` endpoint and is hidden whenever the active provider cannot reach the cloud. The legacy `web.search` alias stays enabled for older sessions.
- `owlen` warns when the active terminal likely lacks 256-color support.
- `config.toml` now carries a schema version (`1.2.0`) and is migrated automatically; deprecated keys such as `agent.max_tool_calls` trigger warnings instead of hard failures.
- Model selector navigation (Tab/Shift-Tab) now switches between local and cloud tabs while preserving selection state.
- Header displays the active model together with its provider (e.g., `Model (Provider)`), improving clarity when swapping backends.
- Documentation refreshed to cover the message handler architecture, the background health worker, multi-provider configuration, and the new provider onboarding checklist.

---

## [0.2.0] - 2025-10-24

### Added
- Cloud usage tracker now persists hourly and weekly token totals, exposes a `:limits` command, and renders live gradient gauges in the header with 80 %/95 % toast notifications.
- Web search tooling is available whenever Ollama Cloud is configured, giving the assistant automatic access to the `web_search` function with runtime toggles via `:web on|off`. Legacy dotted references continue to resolve through the alias layer.
- Provider registry aggregates local and cloud Ollama models, including health checks, scope badges, and graceful fallback between providers.
- Release documentation covers the migration from v0.1, including the new config schema defaults, cloud setup guide, and troubleshooting steps for common errors.

### Changed
- Workspace packages, distribution metadata, and README badges now report version `0.2.0`.
- Chat header adopts a cockpit layout powered by Ratatui 0.29 flex layouts and Tailwind-inspired gradients, clearly surfacing context and quota usage.
- Cloud requests now default to the canonical `https://ollama.com` endpoint and automatically attach the `Authorization: Bearer <API_KEY>` header resolved from config or environment variables.
- Configuration templates enable both local (`providers.ollama`) and cloud (`providers.ollama_cloud`) entries by default, complete with TTLs, context windows, and quota placeholders.

### Fixed
- Selecting Ollama Cloud without a valid API key now surfaces actionable unauthorized toasts and falls back to the last working local provider instead of looping 401 responses.
- Rate-limited cloud responses raise non-fatal warnings so sessions remain usable while the usage tracker records the incident.

---

## [0.1.11] - 2025-10-18

### Changed
- Bump workspace packages and distribution metadata to version `0.1.11`.

## [0.1.10] - 2025-10-03

### Added
- **Material Light Theme**: A new built-in theme, `material-light`, has been added.

### Fixed
- **UI Readability**: Fixed a bug causing unreadable text in light themes.
- **Visual Selection**: The visual selection mode now correctly colors unselected text portions.

### Changed
- **Theme Colors**: The color palettes for `gruvbox`, `rose-pine`, and `monokai` have been corrected.
- **In-App Help**: The `:help` menu has been significantly expanded and updated.

## [0.1.9] - 2025-10-03

*This version corresponds to the release tagged v0.1.10 in the source repository.*

### Added
- **Material Light Theme**: A new built-in theme, `material-light`, has been added.

### Fixed
- **UI Readability**: Fixed a bug causing unreadable text in light themes.
- **Visual Selection**: The visual selection mode now correctly colors unselected text portions.

### Changed
- **Theme Colors**: The color palettes for `gruvbox`, `rose-pine`, and `monokai` have been corrected.
- **In-App Help**: The `:help` menu has been significantly expanded and updated.

## [0.1.8] - 2025-10-02

### Added
- **Command Autocompletion**: Implemented intelligent command suggestions and Tab completion in command mode.

### Changed
- **Build & CI**: Fixed cross-compilation for ARM64, ARMv7, and Windows.

## [0.1.7] - 2025-10-02

### Added
- **Tabbed Help System**: The help menu is now organized into five tabs for easier navigation.
- **Command Aliases**: Added `:o` as a short alias for `:load` / `:open`.

### Changed
- **Session Management**: Improved AI-generated session descriptions.

## [0.1.6] - 2025-10-02

### Added
- **Platform-Specific Storage**: Sessions are now saved to platform-appropriate directories (e.g., `~/.local/share/owlen` on Linux).
- **AI-Generated Session Descriptions**: Conversations can be automatically summarized on save.

### Changed
- **Migration**: Users on older versions can manually move their sessions from `~/.config/owlen/sessions` to the new platform-specific directory.

## [0.1.4] - 2025-10-01

### Added
- **Multi-Platform Builds**: Pre-built binaries are now provided for Linux (x86_64, aarch64, armv7) and Windows (x86_64).
- **AUR Package**: Owlen is now available on the Arch User Repository.

### Changed
- **Build System**: Switched from OpenSSL to rustls for better cross-platform compatibility.
309  CLAUDE.md
@@ -1,309 +0,0 @@
# CLAUDE.md

This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.

## Project Overview

OWLEN is a Rust-powered, terminal-first interface for interacting with local and cloud language models. It uses a multi-provider architecture with vim-style navigation and session management.

**Status**: Alpha (v0.2.0) - core features functional but expect occasional bugs and breaking changes.

## Build, Test & Development Commands

### Building
```bash
# Build all crates
cargo build

# Build release binary
cargo build --release

# Run the TUI (requires Ollama running)
./target/release/owlen
# or
cargo run -p owlen-cli

# Build for specific target (cross-compilation)
dev/local_build.sh x86_64-unknown-linux-gnu
```

### Testing
```bash
# Run all tests
cargo test --all

# Test specific crate
cargo test -p owlen-core
cargo test -p owlen-tui
cargo test -p owlen-providers

# Linting and formatting
cargo clippy --all -- -D warnings
cargo fmt --all -- --check

# Pre-commit hooks (install once with `pre-commit install`)
pre-commit run --all-files
```

### Developer Tasks
```bash
# Regenerate screenshots for documentation
cargo xtask screenshots
cargo xtask screenshots --no-png # skip PNG generation
cargo xtask screenshots --output images/

# Regenerate repository map after structural changes
scripts/gen-repo-map.sh

# Platform compatibility checks
scripts/check-windows.sh # Windows GNU toolchain smoke test
```

### Running Individual Tests
```bash
# Run a specific test by name
cargo test test_name

# Run tests with output
cargo test -- --nocapture

# Run tests in a specific file
cargo test --test integration_test_name
```

## Architecture & Key Concepts

### Workspace Structure (Cargo workspace with 13+ crates)
- **owlen-core**: Core abstractions, provider traits, session management, MCP client layer (UI-agnostic)
- **owlen-tui**: Terminal UI built with ratatui (event loop, rendering, vim modes)
- **owlen-cli**: Entry point that parses args, loads config, launches TUI or headless flows
- **owlen-providers**: Concrete provider adapters (Ollama local, Ollama Cloud)
- **owlen-markdown**: Markdown parsing and rendering
- **crates/mcp/**: Model Context Protocol infrastructure
  - **llm-server**: Wraps owlen-providers behind MCP boundary (generate_text tools)
  - **server**: Generic MCP server for file ops and workspace tools
  - **client**: MCP client implementation
  - **code-server**: Code execution sandboxing
  - **prompt-server**: Template rendering
- **xtask**: Development automation tasks (screenshots, etc.)

### Dependency Boundaries
- **owlen-core is the dependency ceiling**: Must stay free of terminal logic, CLIs, or provider HTTP clients
- **owlen-cli only orchestrates startup/shutdown**: Business logic belongs in owlen-core or library crates
- **owlen-mcp-llm-server is the only crate that directly talks to providers**: UI/CLI communicate through MCP clients

### Multi-Provider Architecture
```
[owlen-tui / owlen-cli]
          │
          │ chat + model requests
          ▼
[owlen-core::ProviderManager] ──> Arc<dyn ModelProvider>
          │                                ▲
          │                                │ implements ModelProvider
          ▼                                │
[owlen-core::mcp::RemoteMcpClient] ────────┘
          │ (JSON-RPC over stdio)
          ▼
┌────────────────────────────────────────────────┐
│ MCP Process Boundary (spawned per provider)    │
│                                                │
│ crates/mcp/llm-server ──> owlen-providers::*   │
└────────────────────────────────────────────────┘
```

Key points:
- **ProviderManager** tracks health, merges model catalogs, and dispatches requests
- **RemoteMcpClient** bridges MCP protocol to ModelProvider trait
- **MCP servers** isolate provider-specific code in separate processes
- **Health & availability** tracked via background workers and surfaced in TUI picker

### Event Flow & TUI Architecture
1. User input → Event loop → Message handler → Session controller → Provider manager → Provider
2. Non-blocking design: TUI remains responsive during streaming (see `agents.md` for planned improvements)
3. Modal workflow: Normal, Insert, Visual, Command modes (vim-inspired)
4. AppMessage stream carries async events (provider responses, health checks)

### Session & Conversation Management
- **Conversation** (owlen-core): Holds messages and metadata
- **SessionController**: High-level orchestrator managing history, context, model switching
- Conversations stored in platform-specific data directory (can be encrypted with AES-GCM)

### Configuration
Platform-specific locations:
- Linux: `~/.config/owlen/config.toml`
- macOS: `~/Library/Application Support/owlen/config.toml`
- Windows: `%APPDATA%\owlen\config.toml`

Commands:
```bash
owlen config init # Create default config
owlen config init --force # Overwrite existing
owlen config path # Print config location
owlen config doctor # Migrate legacy configs
```

## Coding Conventions

### Commit Messages
Follow [Conventional Commits](https://www.conventionalcommits.org/):
```
<type>[optional scope]: <description>

[optional body]

[optional footer(s)]
```

Types: `feat`, `fix`, `docs`, `style`, `refactor`, `test`, `chore`, `build`, `ci`

Example: `feat(provider): add support for Gemini Pro`

### Pre-commit Hooks
Hooks automatically run on commit (install with `pre-commit install`):
- `cargo fmt`
- `cargo check`
- `cargo clippy --all-features`
- File hygiene (trailing whitespace, EOF newlines)

To bypass (not recommended): `git commit --no-verify`

### Style Guidelines
- Run `cargo fmt` before committing
- Address all `cargo clippy` warnings
- Use `#[cfg(test)]` modules for unit tests in same file
- Place integration tests in `tests/` directory

## Provider Development

### Adding a New Provider
Follow `docs/adding-providers.md`:
1. Implement `ModelProvider` trait in `owlen-providers`
2. Set `ProviderMetadata::provider_type` (Local/Cloud)
3. Register with `ProviderManager` in startup code
4. Optionally expose through MCP server
5. Add integration tests following `crates/owlen-providers/tests` pattern
6. Document config in `docs/configuration.md` and default `config.toml`
7. Update `README.md`, `CHANGELOG.md`, `docs/troubleshooting.md`

See `docs/provider-implementation.md` for trait-level details.

### MCP Tool Naming
Enforce spec-compliant identifiers: `^[A-Za-z0-9_-]{1,64}$`
- Use underscores or hyphens (e.g., `web_search`, `filesystem_read`)
- Avoid dotted names (legacy incompatible)
- Qualify with `{server}__{tool}` when multiple servers overlap (e.g., `filesystem__read`)

## Repository Automation

OWLEN includes Git-aware automation for code review and commit templating:

### CLI Commands
```bash
# Generate commit message from staged diff
owlen repo commit-template
owlen repo commit-template --working-tree # inspect unstaged

# Review branch or PR
owlen repo review
owlen repo review --owner Owlibou --repo owlen --number 42 --token-env GITHUB_TOKEN
```

### TUI Commands
```
:repo template # inject commit template into chat
:repo review [--base BRANCH] [--head REF] # review local changes
```

## Key Files & Entry Points

### Main Entry Points
- `crates/owlen-cli/src/main.rs` - CLI entry point (argument parsing, config loading)
- `crates/owlen-tui/src/app/mod.rs` - Main TUI application and event dispatch
- `crates/owlen-core/src/provider.rs` - ModelProvider trait definition

### Configuration & State
- `crates/owlen-core/src/config.rs` - Configuration loading and parsing
- `crates/owlen-core/src/session.rs` - Session and conversation management
- `crates/owlen-core/src/storage.rs` - Persistence layer

### Provider Infrastructure
- `crates/owlen-providers/src/ollama/` - Ollama local and cloud providers
- `crates/mcp/llm-server/src/main.rs` - MCP LLM server process
- `crates/owlen-core/src/mcp/remote_client.rs` - MCP client implementation

## Testing Strategy

### Unit Tests
Place in `#[cfg(test)]` modules within source files for isolated component testing.

### Integration Tests
Place in `tests/` directories:
- `crates/owlen-providers/tests/` - Provider integration tests
- Test registration, model aggregation, request routing, health transitions

### Focus Areas
- Command palette state machine
- Agent response parsing
- MCP protocol abstractions
- Provider manager health cache
- Session controller lifecycle

## Documentation Structure

- `README.md` - User-facing overview, installation, features
- `CONTRIBUTING.md` - Contribution guidelines, development setup
- `docs/architecture.md` - High-level architecture (read first!)
- `docs/repo-map.md` - Workspace layout snapshot
- `docs/adding-providers.md` - Provider implementation checklist
- `docs/provider-implementation.md` - Trait-level provider details
- `docs/testing.md` - Testing guide
- `docs/troubleshooting.md` - Common issues and solutions
- `docs/configuration.md` - Configuration reference
- `docs/platform-support.md` - OS support matrix

## Important Implementation Notes

### When Working on TUI Code
- Modal state machine is critical: Normal ↔ Insert ↔ Visual ↔ Command
- Status line shows current mode (use as regression check)
- Non-blocking event loop planned (see `agents.md`)
- Command palette state lives in `owlen_tui::state`
- Follow Model-View-Update pattern for new features

### When Working on Providers
- Never import providers directly in owlen-tui or owlen-cli
- All provider communication goes through owlen-core abstractions
- Health checks run on background workers
- Model discovery fans out through ProviderManager

### When Working on MCP Integration
- RemoteMcpClient implements both MCP client traits and ModelProvider
- MCP servers are short-lived, narrowly scoped binaries
- Tool calls travel same transport as chat requests
- Consent prompts surface in UI via session events

## Platform Support

- **Primary**: Linux (Arch AUR: `owlen-git`)
- **Supported**: macOS 12+ (requires Command Line Tools for OpenSSL)
- **Experimental**: Windows (GNU toolchain, some Docker features disabled)

Cross-platform testing: Use `dev/local_build.sh` and `scripts/check-windows.sh`

## Dependencies & Async Runtime

- **Async runtime**: tokio with "full" features
- **TUI framework**: ratatui 0.29 with palette features
- **HTTP client**: reqwest with rustls-tls (no native-tls)
- **Database**: SQLx with sqlite, tokio runtime
- **Serialization**: serde + serde_json
- **Testing**: tokio-test for async test utilities

## Security & Privacy

- Local-first: LLM calls route through local Ollama by default
- Session encryption: Set `privacy.encrypt_local_data = true` for AES-GCM storage
- No telemetry sent
- Outbound requests only when explicitly enabling remote tools/providers
- Config migrations carry schema version and warn on deprecated keys
121  CODE_OF_CONDUCT.md
@@ -1,121 +0,0 @@
# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, religion, or sexual identity
and orientation.

We pledge to act and interact in ways that are welcoming, open, and respectful.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the
  overall community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or
  advances of any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email
  address, without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
[security@owlibou.com](mailto:security@owlibou.com). All complaints will be
reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series
of actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interaction in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or
permanent ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within
the community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].

[homepage]: https://www.contributor-covenant.org
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
126  CONTRIBUTING.md
@@ -1,126 +0,0 @@
# Contributing to Owlen

First off, thank you for considering contributing to Owlen! It's people like you that make Owlen such a great tool.

Following these guidelines helps to communicate that you respect the time of the developers managing and developing this open source project. In return, they should reciprocate that respect in addressing your issue, assessing changes, and helping you finalize your pull requests.

## Code of Conduct

This project and everyone participating in it is governed by the [Owlen Code of Conduct](CODE_OF_CONDUCT.md). By participating, you are expected to uphold this code. Please report unacceptable behavior.

## How Can I Contribute?

### Repository map

Need a quick orientation before diving in? Start with the curated [repo map](docs/repo-map.md) for a two-level directory overview. If you move folders around, regenerate it with `scripts/gen-repo-map.sh`.

### Reporting Bugs

This is one of the most helpful ways you can contribute. Before creating a bug report, please check a few things:

1. **Check the [troubleshooting guide](docs/troubleshooting.md).** Your issue might be a common one with a known solution.
2. **Search the existing issues.** It's possible someone has already reported the same bug. If so, add a comment to the existing issue instead of creating a new one.

When you are creating a bug report, please include as many details as possible. Fill out the required template; the information it asks for helps us resolve issues faster.

### Suggesting Enhancements

If you have an idea for a new feature or an improvement to an existing one, we'd love to hear about it. Please provide as much context as you can about what you're trying to achieve.

### Your First Code Contribution

Unsure where to begin contributing to Owlen? You can start by looking through `good first issue` and `help wanted` issues.

### Pull Requests

The process for submitting a pull request is as follows:

1. **Fork the repository** and create your branch from `main`.
2. **Set up pre-commit hooks** (see [Development Setup](#development-setup) below). This will automatically format and lint your code.
3. **Make your changes.**
4. **Run the tests.**
   - `cargo test --all`
5. **Commit your changes.** The pre-commit hooks will automatically run `cargo fmt`, `cargo check`, and `cargo clippy`. If you need to bypass the hooks (not recommended), use `git commit --no-verify`.
6. **Add a clear, concise commit message.** We follow the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification.
7. **Push to your fork** and submit a pull request to Owlen's `main` branch.
8. **Include a clear description** of the problem and solution. Include the relevant issue number if applicable.
9. **Declare AI assistance.** If any part of the patch was generated with an AI tool (e.g., ChatGPT, Claude Code), call that out in the PR description. A human maintainer must review and approve AI-assisted changes before merge.

## Development Setup

To get started with the codebase, you'll need to have Rust installed. Then, you can clone the repository and build the project:

```sh
git clone https://github.com/Owlibou/owlen.git
cd owlen
cargo build
```

### Pre-commit Hooks

We use [pre-commit](https://pre-commit.com/) to automatically run formatting and linting checks before each commit. This helps maintain code quality and consistency.

**Install pre-commit:**

```sh
# Arch Linux
sudo pacman -S pre-commit

# Other Linux/macOS
pip install pre-commit

# Verify installation
pre-commit --version
```

**Set up the hooks:**

```sh
cd owlen
pre-commit install
```

Once installed, the hooks will automatically run on every commit. You can also run them manually:

```sh
# Run on all files
pre-commit run --all-files

# Run on staged files only
pre-commit run
```

The pre-commit hooks will check:
- Code formatting (`cargo fmt`)
- Compilation (`cargo check`)
- Linting (`cargo clippy --all-features`)
- General file hygiene (trailing whitespace, EOF newlines, etc.)

## Coding Style

- We use `cargo fmt` for automated code formatting. Please run it before committing your changes.
- We use `cargo clippy` for linting. Your code should be free of any clippy warnings.

## Commit Message Conventions

We use [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) for our commit messages. This allows for automated changelog generation and makes the project history easier to read.

The basic format is:

```
<type>[optional scope]: <description>

[optional body]

[optional footer(s)]
```

**Types:** `feat`, `fix`, `docs`, `style`, `refactor`, `test`, `chore`, `build`, `ci`.

**Example:**

```
feat(provider): add support for Gemini Pro
```

Thank you for your contribution!
90  Cargo.toml
@@ -1,90 +0,0 @@
[workspace]
resolver = "2"
members = [
    "crates/owlen-core",
    "crates/owlen-ui-common",
    "crates/owlen-tui",
    "crates/owlen-cli",
    "crates/owlen-providers",
    "crates/mcp/server",
    "crates/mcp/llm-server",
    "crates/mcp/client",
    "crates/mcp/code-server",
    "crates/mcp/prompt-server",
    "crates/owlen-markdown",
    "xtask",
]
exclude = []

[workspace.package]
version = "0.2.0"
edition = "2024"
authors = ["Owlibou"]
license = "AGPL-3.0"
repository = "https://somegit.dev/Owlibou/owlen"
homepage = "https://somegit.dev/Owlibou/owlen"
keywords = ["llm", "tui", "cli", "ollama", "chat"]
categories = ["command-line-utilities"]

[workspace.dependencies]
# Async runtime and utilities
tokio = { version = "1.0", features = ["full"] }
tokio-stream = "0.1"
tokio-util = { version = "0.7", features = ["rt"] }
futures = "0.3"
futures-util = "0.3"

# TUI framework
ratatui = { version = "0.29", features = ["palette"] }
crossterm = "0.28.1"
tui-textarea = "0.7"

# HTTP client and JSON handling
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = { version = "1.0" }

# Utilities
uuid = { version = "1.0", features = ["v4", "serde"] }
anyhow = "1.0"
thiserror = "2.0"
nix = "0.29"
which = "6.0"
tempfile = "3.8"
jsonschema = "0.17"
aes-gcm = "0.10"
ring = ">=0.17.12" # Security fix for CVE in 0.17.9 (AES panic vulnerability)
keyring = "3.0"
chrono = { version = "0.4", features = ["serde"] }
urlencoding = "2.1"
regex = "1.10"
sqlx = { version = "0.8", default-features = false, features = ["runtime-tokio", "tls-rustls", "sqlite", "macros", "uuid", "chrono", "migrate"] }
log = "0.4"
dirs = "5.0"
serde_yaml = "0.9"
handlebars = "6.0"
once_cell = "1.19"
base64 = "0.22"
image = { version = "0.25", default-features = false, features = ["png", "jpeg", "gif", "bmp", "webp"] }
mime_guess = "2.0"

# Configuration
toml = "0.8"
shellexpand = "3.1"

# Database
sled = "0.34"

# For better text handling
textwrap = "0.16"

# Async traits
async-trait = "0.1"

# CLI framework
clap = { version = "4.0", features = ["derive"] }

# Dev dependencies
tokio-test = "0.4"

# For more keys and their definitions, see https://doc.rust-lang.org/cargo/reference/manifest.html
661  LICENSE
@@ -1,661 +0,0 @@
                    GNU AFFERO GENERAL PUBLIC LICENSE
                       Version 3, 19 November 2007

 Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
 Everyone is permitted to copy and distribute verbatim copies
 of this license document, but changing it is not allowed.

                             Preamble

  The GNU Affero General Public License is a free, copyleft license for
software and other kinds of works, specifically designed to ensure
cooperation with the community in the case of network server software.

  The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
our General Public Licenses are intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users.

  When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.

  Developers that use our General Public Licenses protect your rights
with two steps: (1) assert copyright on the software, and (2) offer
you this License which gives you legal permission to copy, distribute
and/or modify the software.

  A secondary benefit of defending all users' freedom is that
improvements made in alternate versions of the program, if they
receive widespread use, become available for other developers to
incorporate. Many developers of free software are heartened and
encouraged by the resulting cooperation. However, in the case of
software used on network servers, this result may fail to come about.
The GNU General Public License permits making a modified version and
letting the public access it on a server without ever releasing its
source code to the public.

  The GNU Affero General Public License is designed specifically to
ensure that, in such cases, the modified source code becomes available
to the community. It requires the operator of a network server to
provide the source code of the modified version running there to the
users of that server. Therefore, public use of a modified version, on
a publicly accessible server, gives the public access to the source
code of the modified version.

  An older license, called the Affero General Public License and
published by Affero, was designed to accomplish similar goals. This is
a different license, not a version of the Affero GPL, but Affero has
released a new version of the Affero GPL which permits relicensing under
this license.

  The precise terms and conditions for copying, distribution and
modification follow.

                       TERMS AND CONDITIONS

  0. Definitions.

  "This License" refers to version 3 of the GNU Affero General Public License.

  "Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.

  "The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.

  To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.

  A "covered work" means either the unmodified Program or a work based
on the Program.

  To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.

  To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.

  An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.

  1. Source Code.

  The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.

  A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.

  The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.

  The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.

  The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.

  The Corresponding Source for a work in source code form is that
same work.

  2. Basic Permissions.

  All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.

  You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
|
||||
|
||||
Conveying under any other circumstances is permitted solely under
|
||||
the conditions stated below. Sublicensing is not allowed; section 10
|
||||
makes it unnecessary.
|
||||
|
||||
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||
|
||||
No covered work shall be deemed part of an effective technological
|
||||
measure under any applicable law fulfilling obligations under article
|
||||
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||
similar laws prohibiting or restricting circumvention of such
|
||||
measures.
|
||||
|
||||
When you convey a covered work, you waive any legal power to forbid
|
||||
circumvention of technological measures to the extent such circumvention
|
||||
is effected by exercising rights under this License with respect to
|
||||
the covered work, and you disclaim any intention to limit operation or
|
||||
modification of the work as a means of enforcing, against the work's
|
||||
users, your or third parties' legal rights to forbid circumvention of
|
||||
technological measures.
|
||||
|
||||
4. Conveying Verbatim Copies.
|
||||
|
||||
You may convey verbatim copies of the Program's source code as you
|
||||
receive it, in any medium, provided that you conspicuously and
|
||||
appropriately publish on each copy an appropriate copyright notice;
|
||||
keep intact all notices stating that this License and any
|
||||
non-permissive terms added in accord with section 7 apply to the code;
|
||||
keep intact all notices of the absence of any warranty; and give all
|
||||
recipients a copy of this License along with the Program.
|
||||
|
||||
You may charge any price or no price for each copy that you convey,
|
||||
and you may offer support or warranty protection for a fee.
|
||||
|
||||
5. Conveying Modified Source Versions.
|
||||
|
||||
You may convey a work based on the Program, or the modifications to
|
||||
produce it from the Program, in the form of source code under the
|
||||
terms of section 4, provided that you also meet all of these conditions:
|
||||
|
||||
a) The work must carry prominent notices stating that you modified
|
||||
it, and giving a relevant date.
|
||||
|
||||
b) The work must carry prominent notices stating that it is
|
||||
released under this License and any conditions added under section
|
||||
7. This requirement modifies the requirement in section 4 to
|
||||
"keep intact all notices".
|
||||
|
||||
c) You must license the entire work, as a whole, under this
|
||||
License to anyone who comes into possession of a copy. This
|
||||
License will therefore apply, along with any applicable section 7
|
||||
additional terms, to the whole of the work, and all its parts,
|
||||
regardless of how they are packaged. This License gives no
|
||||
permission to license the work in any other way, but it does not
|
||||
invalidate such permission if you have separately received it.
|
||||
|
||||
d) If the work has interactive user interfaces, each must display
|
||||
Appropriate Legal Notices; however, if the Program has interactive
|
||||
interfaces that do not display Appropriate Legal Notices, your
|
||||
work need not make them do so.
|
||||
|
||||
A compilation of a covered work with other separate and independent
|
||||
works, which are not by their nature extensions of the covered work,
|
||||
and which are not combined with it such as to form a larger program,
|
||||
in or on a volume of a storage or distribution medium, is called an
|
||||
"aggregate" if the compilation and its resulting copyright are not
|
||||
used to limit the access or legal rights of the compilation's users
|
||||
beyond what the individual works permit. Inclusion of a covered work
|
||||
in an aggregate does not cause this License to apply to the other
|
||||
parts of the aggregate.
|
||||
|
||||
6. Conveying Non-Source Forms.
|
||||
|
||||
You may convey a covered work in object code form under the terms
|
||||
of sections 4 and 5, provided that you also convey the
|
||||
machine-readable Corresponding Source under the terms of this License,
|
||||
in one of these ways:
|
||||
|
||||
a) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by the
|
||||
Corresponding Source fixed on a durable physical medium
|
||||
customarily used for software interchange.
|
||||
|
||||
b) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by a
|
||||
written offer, valid for at least three years and valid for as
|
||||
long as you offer spare parts or customer support for that product
|
||||
model, to give anyone who possesses the object code either (1) a
|
||||
copy of the Corresponding Source for all the software in the
|
||||
product that is covered by this License, on a durable physical
|
||||
medium customarily used for software interchange, for a price no
|
||||
more than your reasonable cost of physically performing this
|
||||
conveying of source, or (2) access to copy the
|
||||
Corresponding Source from a network server at no charge.
|
||||
|
||||
c) Convey individual copies of the object code with a copy of the
|
||||
written offer to provide the Corresponding Source. This
|
||||
alternative is allowed only occasionally and noncommercially, and
|
||||
only if you received the object code with such an offer, in accord
|
||||
with subsection 6b.
|
||||
|
||||
d) Convey the object code by offering access from a designated
|
||||
place (gratis or for a charge), and offer equivalent access to the
|
||||
Corresponding Source in the same way through the same place at no
|
||||
further charge. You need not require recipients to copy the
|
||||
Corresponding Source along with the object code. If the place to
|
||||
copy the object code is a network server, the Corresponding Source
|
||||
may be on a different server (operated by you or a third party)
|
||||
that supports equivalent copying facilities, provided you maintain
|
||||
clear directions next to the object code saying where to find the
|
||||
Corresponding Source. Regardless of what server hosts the
|
||||
Corresponding Source, you remain obligated to ensure that it is
|
||||
available for as long as needed to satisfy these requirements.
|
||||
|
||||
e) Convey the object code using peer-to-peer transmission, provided
|
||||
you inform other peers where the object code and Corresponding
|
||||
Source of the work are being offered to the general public at no
|
||||
charge under subsection 6d.
|
||||
|
||||
A separable portion of the object code, whose source code is excluded
|
||||
from the Corresponding Source as a System Library, need not be
|
||||
included in conveying the object code work.
|
||||
|
||||
A "User Product" is either (1) a "consumer product", which means any
|
||||
tangible personal property which is normally used for personal, family,
|
||||
or household purposes, or (2) anything designed or sold for incorporation
|
||||
into a dwelling. In determining whether a product is a consumer product,
|
||||
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||
product received by a particular user, "normally used" refers to a
|
||||
typical or common use of that class of product, regardless of the status
|
||||
of the particular user or of the way in which the particular user
|
||||
actually uses, or expects or is expected to use, the product. A product
|
||||
is a consumer product regardless of whether the product has substantial
|
||||
commercial, industrial or non-consumer uses, unless such uses represent
|
||||
the only significant mode of use of the product.
|
||||
|
||||
"Installation Information" for a User Product means any methods,
|
||||
procedures, authorization keys, or other information required to install
|
||||
and execute modified versions of a covered work in that User Product from
|
||||
a modified version of its Corresponding Source. The information must
|
||||
suffice to ensure that the continued functioning of the modified object
|
||||
code is in no case prevented or interfered with solely because
|
||||
modification has been made.
|
||||
|
||||
If you convey an object code work under this section in, or with, or
|
||||
specifically for use in, a User Product, and the conveying occurs as
|
||||
part of a transaction in which the right of possession and use of the
|
||||
User Product is transferred to the recipient in perpetuity or for a
|
||||
fixed term (regardless of how the transaction is characterized), the
|
||||
Corresponding Source conveyed under this section must be accompanied
|
||||
by the Installation Information. But this requirement does not apply
|
||||
if neither you nor any third party retains the ability to install
|
||||
modified object code on the User Product (for example, the work has
|
||||
been installed in ROM).
|
||||
|
||||
The requirement to provide Installation Information does not include a
|
||||
requirement to continue to provide support service, warranty, or updates
|
||||
for a work that has been modified or installed by the recipient, or for
|
||||
the User Product in which it has been modified or installed. Access to a
|
||||
network may be denied when the modification itself materially and
|
||||
adversely affects the operation of the network or violates the rules and
|
||||
protocols for communication across the network.
|
||||
|
||||
Corresponding Source conveyed, and Installation Information provided,
|
||||
in accord with this section must be in a format that is publicly
|
||||
documented (and with an implementation available to the public in
|
||||
source code form), and must require no special password or key for
|
||||
unpacking, reading or copying.
|
||||
|
||||
7. Additional Terms.
|
||||
|
||||
"Additional permissions" are terms that supplement the terms of this
|
||||
License by making exceptions from one or more of its conditions.
|
||||
Additional permissions that are applicable to the entire Program shall
|
||||
be treated as though they were included in this License, to the extent
|
||||
that they are valid under applicable law. If additional permissions
|
||||
apply only to part of the Program, that part may be used separately
|
||||
under those permissions, but the entire Program remains governed by
|
||||
this License without regard to the additional permissions.
|
||||
|
||||
When you convey a copy of a covered work, you may at your option
|
||||
remove any additional permissions from that copy, or from any part of
|
||||
it. (Additional permissions may be written to require their own
|
||||
removal in certain cases when you modify the work.) You may place
|
||||
additional permissions on material, added by you to a covered work,
|
||||
for which you have or can give appropriate copyright permission.
|
||||
|
||||
Notwithstanding any other provision of this License, for material you
|
||||
add to a covered work, you may (if authorized by the copyright holders of
|
||||
that material) supplement the terms of this License with terms:
|
||||
|
||||
a) Disclaiming warranty or limiting liability differently from the
|
||||
terms of sections 15 and 16 of this License; or
|
||||
|
||||
b) Requiring preservation of specified reasonable legal notices or
|
||||
author attributions in that material or in the Appropriate Legal
|
||||
Notices displayed by works containing it; or
|
||||
|
||||
c) Prohibiting misrepresentation of the origin of that material, or
|
||||
requiring that modified versions of such material be marked in
|
||||
reasonable ways as different from the original version; or
|
||||
|
||||
d) Limiting the use for publicity purposes of names of licensors or
|
||||
authors of the material; or
|
||||
|
||||
e) Declining to grant rights under trademark law for use of some
|
||||
trade names, trademarks, or service marks; or
|
||||
|
||||
f) Requiring indemnification of licensors and authors of that
|
||||
material by anyone who conveys the material (or modified versions of
|
||||
it) with contractual assumptions of liability to the recipient, for
|
||||
any liability that these contractual assumptions directly impose on
|
||||
those licensors and authors.
|
||||
|
||||
All other non-permissive additional terms are considered "further
|
||||
restrictions" within the meaning of section 10. If the Program as you
|
||||
received it, or any part of it, contains a notice stating that it is
|
||||
governed by this License along with a term that is a further
|
||||
restriction, you may remove that term. If a license document contains
|
||||
a further restriction but permits relicensing or conveying under this
|
||||
License, you may add to a covered work material governed by the terms
|
||||
of that license document, provided that the further restriction does
|
||||
not survive such relicensing or conveying.
|
||||
|
||||
If you add terms to a covered work in accord with this section, you
|
||||
must place, in the relevant source files, a statement of the
|
||||
additional terms that apply to those files, or a notice indicating
|
||||
where to find the applicable terms.
|
||||
|
||||
Additional terms, permissive or non-permissive, may be stated in the
|
||||
form of a separately written license, or stated as exceptions;
|
||||
the above requirements apply either way.
|
||||
|
||||
8. Termination.
|
||||
|
||||
You may not propagate or modify a covered work except as expressly
|
||||
provided under this License. Any attempt otherwise to propagate or
|
||||
modify it is void, and will automatically terminate your rights under
|
||||
this License (including any patent licenses granted under the third
|
||||
paragraph of section 11).
|
||||
|
||||
However, if you cease all violation of this License, then your
|
||||
license from a particular copyright holder is reinstated (a)
|
||||
provisionally, unless and until the copyright holder explicitly and
|
||||
finally terminates your license, and (b) permanently, if the copyright
|
||||
holder fails to notify you of the violation by some reasonable means
|
||||
prior to 60 days after the cessation.
|
||||
|
||||
Moreover, your license from a particular copyright holder is
|
||||
reinstated permanently if the copyright holder notifies you of the
|
||||
violation by some reasonable means, this is the first time you have
|
||||
received notice of violation of this License (for any work) from that
|
||||
copyright holder, and you cure the violation prior to 30 days after
|
||||
your receipt of the notice.
|
||||
|
||||
Termination of your rights under this section does not terminate the
|
||||
licenses of parties who have received copies or rights from you under
|
||||
this License. If your rights have been terminated and not permanently
|
||||
reinstated, you do not qualify to receive new licenses for the same
|
||||
material under section 10.
|
||||
|
||||
9. Acceptance Not Required for Having Copies.
|
||||
|
||||
You are not required to accept this License in order to receive or
|
||||
run a copy of the Program. Ancillary propagation of a covered work
|
||||
occurring solely as a consequence of using peer-to-peer transmission
|
||||
to receive a copy likewise does not require acceptance. However,
|
||||
nothing other than this License grants you permission to propagate or
|
||||
modify any covered work. These actions infringe copyright if you do
|
||||
not accept this License. Therefore, by modifying or propagating a
|
||||
covered work, you indicate your acceptance of this License to do so.
|
||||
|
||||
10. Automatic Licensing of Downstream Recipients.
|
||||
|
||||
Each time you convey a covered work, the recipient automatically
|
||||
receives a license from the original licensors, to run, modify and
|
||||
propagate that work, subject to this License. You are not responsible
|
||||
for enforcing compliance by third parties with this License.
|
||||
|
||||
An "entity transaction" is a transaction transferring control of an
|
||||
organization, or substantially all assets of one, or subdividing an
|
||||
organization, or merging organizations. If propagation of a covered
|
||||
work results from an entity transaction, each party to that
|
||||
transaction who receives a copy of the work also receives whatever
|
||||
licenses to the work the party's predecessor in interest had or could
|
||||
give under the previous paragraph, plus a right to possession of the
|
||||
Corresponding Source of the work from the predecessor in interest, if
|
||||
the predecessor has it or can get it with reasonable efforts.
|
||||
|
||||
You may not impose any further restrictions on the exercise of the
|
||||
rights granted or affirmed under this License. For example, you may
|
||||
not impose a license fee, royalty, or other charge for exercise of
|
||||
rights granted under this License, and you may not initiate litigation
|
||||
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||
any patent claim is infringed by making, using, selling, offering for
|
||||
sale, or importing the Program or any portion of it.
|
||||
|
||||
11. Patents.
|
||||
|
||||
A "contributor" is a copyright holder who authorizes use under this
|
||||
License of the Program or a work on which the Program is based. The
|
||||
work thus licensed is called the contributor's "contributor version".
|
||||
|
||||
A contributor's "essential patent claims" are all patent claims
|
||||
owned or controlled by the contributor, whether already acquired or
|
||||
hereafter acquired, that would be infringed by some manner, permitted
|
||||
by this License, of making, using, or selling its contributor version,
|
||||
but do not include claims that would be infringed only as a
|
||||
consequence of further modification of the contributor version. For
|
||||
purposes of this definition, "control" includes the right to grant
|
||||
patent sublicenses in a manner consistent with the requirements of
|
||||
this License.
|
||||
|
||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||
patent license under the contributor's essential patent claims, to
|
||||
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||
propagate the contents of its contributor version.
|
||||
|
||||
In the following three paragraphs, a "patent license" is any express
|
||||
agreement or commitment, however denominated, not to enforce a patent
|
||||
(such as an express permission to practice a patent or covenant not to
|
||||
sue for patent infringement). To "grant" such a patent license to a
|
||||
party means to make such an agreement or commitment not to enforce a
|
||||
patent against the party.
|
||||
|
||||
If you convey a covered work, knowingly relying on a patent license,
|
||||
and the Corresponding Source of the work is not available for anyone
|
||||
to copy, free of charge and under the terms of this License, through a
|
||||
publicly available network server or other readily accessible means,
|
||||
then you must either (1) cause the Corresponding Source to be so
|
||||
available, or (2) arrange to deprive yourself of the benefit of the
|
||||
patent license for this particular work, or (3) arrange, in a manner
|
||||
consistent with the requirements of this License, to extend the patent
|
||||
license to downstream recipients. "Knowingly relying" means you have
|
||||
actual knowledge that, but for the patent license, your conveying the
|
||||
covered work in a country, or your recipient's use of the covered work
|
||||
in a country, would infringe one or more identifiable patents in that
|
||||
country that you have reason to believe are valid.
|
||||
|
||||
If, pursuant to or in connection with a single transaction or
|
||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||
covered work, and grant a patent license to some of the parties
|
||||
receiving the covered work authorizing them to use, propagate, modify
|
||||
or convey a specific copy of the covered work, then the patent license
|
||||
you grant is automatically extended to all recipients of the covered
|
||||
work and works based on it.
|
||||
|
||||
A patent license is "discriminatory" if it does not include within
|
||||
the scope of its coverage, prohibits the exercise of, or is
|
||||
conditioned on the non-exercise of one or more of the rights that are
|
||||
specifically granted under this License. You may not convey a covered
|
||||
work if you are a party to an arrangement with a third party that is
|
||||
in the business of distributing software, under which you make payment
|
||||
to the third party based on the extent of your activity of conveying
|
||||
the work, and under which the third party grants, to any of the
|
||||
parties who would receive the covered work from you, a discriminatory
|
||||
patent license (a) in connection with copies of the covered work
|
||||
conveyed by you (or copies made from those copies), or (b) primarily
|
||||
for and in connection with specific products or compilations that
|
||||
contain the covered work, unless you entered into that arrangement,
|
||||
or that patent license was granted, prior to 28 March 2007.
|
||||
|
||||
Nothing in this License shall be construed as excluding or limiting
|
||||
any implied license or other defenses to infringement that may
|
||||
otherwise be available to you under applicable patent law.
|
||||
|
||||
12. No Surrender of Others' Freedom.
|
||||
|
||||
If conditions are imposed on you (whether by court order, agreement or
|
||||
otherwise) that contradict the conditions of this License, they do not
|
||||
excuse you from the conditions of this License. If you cannot convey a
|
||||
covered work so as to satisfy simultaneously your obligations under this
|
||||
License and any other pertinent obligations, then as a consequence you may
|
||||
not convey it at all. For example, if you agree to terms that obligate you
|
||||
to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Remote Network Interaction; Use with the GNU General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, if you modify the
|
||||
Program, your modified version must prominently offer all users
|
||||
interacting with it remotely through a computer network (if your version
|
||||
supports such interaction) an opportunity to receive the Corresponding
|
||||
Source of your version by providing access to the Corresponding Source
|
||||
from a network server at no charge, through some standard or customary
|
||||
means of facilitating copying of software. This Corresponding Source
|
||||
shall include the Corresponding Source for any work covered by version 3
|
||||
of the GNU General Public License that is incorporated pursuant to the
|
||||
following paragraph.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the work with which it is combined will remain governed by version
|
||||
3 of the GNU General Public License.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU Affero General Public License from time to time. Such new versions
|
||||
will be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU Affero General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU Affero General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU Affero General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
Later license versions may give you additional or different
|
||||
permissions. However, no additional obligations are imposed on any
|
||||
author or copyright holder as a result of your choosing to follow a
|
||||
later version.
|
||||
|
||||
15. Disclaimer of Warranty.
|
||||
|
||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||
|
||||
16. Limitation of Liability.
|
||||
|
||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||
SUCH DAMAGES.
|
||||
|
||||
17. Interpretation of Sections 15 and 16.
|
||||
|
||||
If the disclaimer of warranty and limitation of liability provided
|
||||
above cannot be given local legal effect according to their terms,
|
||||
reviewing courts shall apply local law that most closely approximates
|
||||
an absolute waiver of all civil liability in connection with the
|
||||
Program, unless a warranty or assumption of liability accompanies a
|
||||
copy of the Program in return for a fee.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
How to Apply These Terms to Your New Programs
|
||||
|
||||
If you develop a new program, and you want it to be of the greatest
|
||||
possible use to the public, the best way to achieve this is to make it
|
||||
free software which everyone can redistribute and change under these terms.
|
||||
|
||||
To do so, attach the following notices to the program. It is safest
|
||||
to attach them to the start of each source file to most effectively
|
||||
state the exclusion of warranty; and each file should have at least
|
||||
the "copyright" line and a pointer to where the full notice is found.
|
||||
|
||||
<one line to give the program's name and a brief idea of what it does.>
|
||||
Copyright (C) <year> <name of author>
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
Also add information on how to contact you by electronic and paper mail.
|
||||
|
||||
If your software can interact with users remotely through a computer
|
||||
network, you should also make sure that it provides a way for users to
|
||||
get its source. For example, if your program is a web application, its
|
||||
interface could display a "Source" link that leads users to an archive
|
||||
of the code. There are many ways you could offer source, and different
|
||||
solutions will be better for different programs; see section 13 for the
|
||||
specific requirements.
|
||||
|
||||
You should also get your employer (if you work as a programmer) or school,
|
||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||
For more information on this, and how to apply and follow the GNU AGPL, see
|
||||
<https://www.gnu.org/licenses/>.
|
||||
49
PKGBUILD
@@ -1,49 +0,0 @@
# Maintainer: vikingowl <christian@nachtigall.dev>
pkgname=owlen
pkgver=0.2.0
pkgrel=1
pkgdesc="Terminal User Interface LLM client for Ollama with chat and code assistance features"
arch=('x86_64')
url="https://somegit.dev/Owlibou/owlen"
license=('AGPL-3.0-or-later')
depends=('gcc-libs')
makedepends=('cargo' 'git')
options=(!lto) # avoid LTO-linked ring symbol drop with lld
source=("$pkgname-$pkgver.tar.gz::$url/archive/v$pkgver.tar.gz")
sha256sums=('cabb1cfdfc247b5d008c6c5f94e13548bcefeba874aae9a9d45aa95ae1c085ec')

prepare() {
    cd $pkgname
    cargo fetch --target "$(rustc -vV | sed -n 's/host: //p')"
}

build() {
    cd $pkgname
    export RUSTFLAGS="${RUSTFLAGS:-} -C link-arg=-Wl,--no-as-needed"
    export CARGO_PROFILE_RELEASE_LTO=false
    export CARGO_TARGET_DIR=target
    cargo build --frozen --release --all-features
}

check() {
    cd $pkgname
    export RUSTFLAGS="${RUSTFLAGS:-} -C link-arg=-Wl,--no-as-needed"
    cargo test --frozen --all-features
}

package() {
    cd $pkgname

    # Install binaries
    install -Dm755 target/release/owlen "$pkgdir/usr/bin/owlen"
    install -Dm755 target/release/owlen-code "$pkgdir/usr/bin/owlen-code"

    # Install documentation
    install -Dm644 README.md "$pkgdir/usr/share/doc/$pkgname/README.md"

    # Install built-in themes for reference
    install -Dm644 themes/README.md "$pkgdir/usr/share/$pkgname/themes/README.md"
    for theme in themes/*.toml; do
        # Quote "$theme" inside basename so theme files with spaces survive
        install -Dm644 "$theme" "$pkgdir/usr/share/$pkgname/themes/$(basename "$theme")"
    done
}
@@ -1,400 +0,0 @@

# ProviderManager Clone Overhead Optimizations

## Summary

This document describes the optimizations applied to `/home/cnachtigall/data/git/projects/Owlibou/owlen/crates/owlen-core/src/provider/manager.rs` to reduce clone overhead as identified in the project analysis report.

## Problems Identified

1. **Lines 94-100** (`list_all_models`): Clones all provider Arc handles and IDs unnecessarily into an intermediate Vec
2. **Lines 162-168** (`refresh_health`): Collects into Vec with unnecessary clones before spawning async tasks
3. **Line 220** (`provider_statuses()`): Clones entire HashMap on every call

The report estimated that 15-20% of `list_all_models` time was spent on String clones alone.

## Optimizations Applied

### 1. Change `status_cache` to Arc-Wrapped HashMap

**File**: `crates/owlen-core/src/provider/manager.rs`

**Line 28**: Change struct definition
```rust
// Before:
status_cache: RwLock<HashMap<String, ProviderStatus>>,

// After:
status_cache: RwLock<Arc<HashMap<String, ProviderStatus>>>,
```

**Rationale**: Using `Arc<HashMap>` allows cheap cloning via reference counting instead of deep-copying the entire HashMap.

### 2. Update Constructor (`new`)

**Lines 41-44**:
```rust
// Before:
Self {
    providers: RwLock::new(HashMap::new()),
    status_cache: RwLock::new(status_cache),
}

// After:
Self {
    providers: RwLock::new(HashMap::new()),
    status_cache: RwLock::new(Arc::new(status_cache)),
}
```

### 3. Update Default Implementation

**Lines 476-479**:
```rust
// Before:
Self {
    providers: RwLock::new(HashMap::new()),
    status_cache: RwLock::new(HashMap::new()),
}

// After:
Self {
    providers: RwLock::new(HashMap::new()),
    status_cache: RwLock::new(Arc::new(HashMap::new())),
}
```

### 4. Update `register_provider` (Copy-on-Write Pattern)

**Lines 56-59**:
```rust
// Before:
self.status_cache
    .write()
    .await
    .insert(provider_id, ProviderStatus::Unavailable);

// After:
// Update status cache with copy-on-write
let mut guard = self.status_cache.write().await;
let mut new_cache = (**guard).clone();
new_cache.insert(provider_id, ProviderStatus::Unavailable);
*guard = Arc::new(new_cache);
```

**Rationale**: When updating the HashMap, we clone the inner HashMap (not the Arc), modify it, then wrap it in a new Arc. This keeps the immutability contract while allowing readers to continue using old snapshots.
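
As a quick illustration of why readers stay consistent under this scheme, here is a minimal, self-contained sketch of the snapshot behavior (plain `std`, no Tokio or project types):

```rust
use std::collections::HashMap;
use std::sync::Arc;

fn main() {
    // Build the initial cache and wrap it in an Arc.
    let mut initial = HashMap::new();
    initial.insert("ollama".to_string(), "available");
    let mut cache: Arc<HashMap<String, &str>> = Arc::new(initial);

    // A reader takes a snapshot: just an atomic refcount bump.
    let snapshot = Arc::clone(&cache);

    // A writer performs copy-on-write: clone the map, mutate, re-wrap.
    let mut new_cache = (*cache).clone();
    new_cache.insert("ollama_cloud".to_string(), "unavailable");
    cache = Arc::new(new_cache);

    // The reader still sees its old, consistent snapshot...
    assert_eq!(snapshot.len(), 1);
    // ...while new readers see the updated map.
    assert_eq!(cache.len(), 2);
    println!("old snapshot: {snapshot:?}\nnew cache: {cache:?}");
}
```
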

### 5. Update `generate` Method (Two Locations)

**Lines 76-79** (Available status):
```rust
// Before:
self.status_cache
    .write()
    .await
    .insert(provider_id.to_string(), ProviderStatus::Available);

// After:
// Update status cache with copy-on-write
let mut guard = self.status_cache.write().await;
let mut new_cache = (**guard).clone();
new_cache.insert(provider_id.to_string(), ProviderStatus::Available);
*guard = Arc::new(new_cache);
```

**Lines 83-86** (Unavailable status):
```rust
// Before:
self.status_cache
    .write()
    .await
    .insert(provider_id.to_string(), ProviderStatus::Unavailable);

// After:
// Update status cache with copy-on-write
let mut guard = self.status_cache.write().await;
let mut new_cache = (**guard).clone();
new_cache.insert(provider_id.to_string(), ProviderStatus::Unavailable);
*guard = Arc::new(new_cache);
```

### 6. Update `list_all_models` (Avoid Intermediate Vec)

**Lines 94-132**:
```rust
// Before:
let providers: Vec<(String, Arc<dyn ModelProvider>)> = {
    let guard = self.providers.read().await;
    guard
        .iter()
        .map(|(id, provider)| (id.clone(), Arc::clone(provider)))
        .collect()
};

let mut tasks = FuturesUnordered::new();

for (provider_id, provider) in providers {
    tasks.push(async move {
        let log_id = provider_id.clone();
        // ...
    });
}

// After:
let mut tasks = FuturesUnordered::new();

{
    let guard = self.providers.read().await;
    for (provider_id, provider) in guard.iter() {
        // Clone Arc and String, but keep lock held for shorter duration
        let provider_id = provider_id.clone();
        let provider = Arc::clone(provider);

        tasks.push(async move {
            // No need for log_id clone - just use provider_id directly
            // ...
        });
    }
}
```

**Rationale**:
- Eliminates intermediate Vec allocation
- Still clones provider_id and Arc, but does so inline during iteration
- Lock is held only during spawning (which is fast), not during actual health checks
- Removes unnecessary `log_id` clone inside async block

### 7. Update `list_all_models` Status Updates (Copy-on-Write)

**Lines 149-153**:
```rust
// Before:
{
    let mut guard = self.status_cache.write().await;
    for (provider_id, status) in status_updates {
        guard.insert(provider_id, status);
    }
}

// After:
{
    let mut guard = self.status_cache.write().await;
    let mut new_cache = (**guard).clone();
    for (provider_id, status) in status_updates {
        new_cache.insert(provider_id, status);
    }
    *guard = Arc::new(new_cache);
}
```

### 8. Update `refresh_health` (Avoid Intermediate Vec)

**Lines 162-184**:
```rust
// Before:
let providers: Vec<(String, Arc<dyn ModelProvider>)> = {
    let guard = self.providers.read().await;
    guard
        .iter()
        .map(|(id, provider)| (id.clone(), Arc::clone(provider)))
        .collect()
};

let mut tasks = FuturesUnordered::new();
for (provider_id, provider) in providers {
    tasks.push(async move {
        // ...
    });
}

// After:
let mut tasks = FuturesUnordered::new();

{
    let guard = self.providers.read().await;
    for (provider_id, provider) in guard.iter() {
        let provider_id = provider_id.clone();
        let provider = Arc::clone(provider);

        tasks.push(async move {
            // ...
        });
    }
}
```

### 9. Update `refresh_health` Status Updates (Copy-on-Write)

**Lines 191-194**:
```rust
// Before:
{
    let mut guard = self.status_cache.write().await;
    for (provider_id, status) in &updates {
        guard.insert(provider_id.clone(), *status);
    }
}

// After:
{
    let mut guard = self.status_cache.write().await;
    let mut new_cache = (**guard).clone();
    for (provider_id, status) in &updates {
        new_cache.insert(provider_id.clone(), *status);
    }
    *guard = Arc::new(new_cache);
}
```

### 10. Update `provider_statuses()` Return Type

**Lines 218-221**:
```rust
// Before:
pub async fn provider_statuses(&self) -> HashMap<String, ProviderStatus> {
    let guard = self.status_cache.read().await;
    guard.clone()
}

// After:
/// Snapshot the currently cached statuses.
/// Returns an Arc to avoid cloning the entire HashMap on every call.
pub async fn provider_statuses(&self) -> Arc<HashMap<String, ProviderStatus>> {
    let guard = self.status_cache.read().await;
    Arc::clone(&guard)
}
```

**Rationale**: Returns an Arc for cheap reference-counted sharing instead of a deep clone.

## Call Site Updates

### File: `crates/owlen-cli/src/commands/providers.rs`

**Lines 218-220**:
```rust
// Before:
let statuses = manager.provider_statuses().await;
print_models(records, models, statuses);

// After:
let statuses = manager.provider_statuses().await;
print_models(records, models, (*statuses).clone());
```

**Rationale**: `print_models` expects an owned HashMap. Clone once at the call site instead of always cloning in `provider_statuses()`.

### File: `crates/owlen-tui/src/app/worker.rs`

**Add import**:
```rust
use std::collections::HashMap;
```

**Lines 20-52**:
```rust
// Before:
let mut last_statuses = provider_manager.provider_statuses().await;

loop {
    // ...
    let statuses = provider_manager.refresh_health().await;

    for (provider_id, status) in statuses {
        let changed = match last_statuses.get(&provider_id) {
            Some(previous) => previous != &status,
            None => true,
        };

        last_statuses.insert(provider_id.clone(), status);

        if changed && message_tx.send(/* ... */).is_err() {
            return;
        }
    }
}

// After:
let mut last_statuses: Arc<HashMap<String, ProviderStatus>> =
    provider_manager.provider_statuses().await;

loop {
    // ...
    let statuses = provider_manager.refresh_health().await;

    for (provider_id, status) in &statuses {
        let changed = match last_statuses.get(provider_id) {
            Some(previous) => previous != status,
            None => true,
        };

        if changed && message_tx.send(AppMessage::ProviderStatus {
            provider_id: provider_id.clone(),
            status: *status,
        }).is_err() {
            return;
        }
    }

    // Update last_statuses after processing all changes
    last_statuses = Arc::new(statuses);
}
```

**Rationale**:
- Store an Arc instead of an owned HashMap
- Iterate over references in the loop (avoids moving the statuses HashMap)
- Replace the entire Arc after all changes are processed
- Only clone provider_id when sending a message

## Performance Impact

**Expected improvements**:
- **`list_all_models`**: 15-20% reduction in execution time (eliminates String clone overhead)
- **`refresh_health`**: Similar benefits, plus avoids the intermediate Vec allocation
- **`provider_statuses`**: ~100x faster for typical HashMap sizes (Arc clone vs deep clone)
- **Background worker**: Reduced allocations in the hot loop (30-second interval)

**Trade-offs**:
- Status updates now require cloning the HashMap (copy-on-write)
- However, status updates are infrequent compared to reads
- Overall: optimizes the hot path (reads) at the expense of the cold path (writes)
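
The ~100x figure above is an expectation rather than a measurement; a rough, self-contained sketch (plain `std`, not a rigorous benchmark) to sanity-check the direction of the claim locally might look like this:

```rust
use std::collections::HashMap;
use std::hint::black_box;
use std::sync::Arc;
use std::time::Instant;

fn main() {
    // A map roughly the size of a realistic provider status cache.
    let map: HashMap<String, u8> = (0..64).map(|i| (format!("provider-{i}"), 0)).collect();
    let arc = Arc::new(map.clone());
    let iters = 100_000;

    let t = Instant::now();
    for _ in 0..iters {
        black_box(map.clone()); // deep clone: re-allocates and copies every entry
    }
    let deep = t.elapsed();

    let t = Instant::now();
    for _ in 0..iters {
        black_box(Arc::clone(&arc)); // shallow clone: a single atomic increment
    }
    let shallow = t.elapsed();

    println!("deep clone: {deep:?}, Arc clone: {shallow:?}");
}
```

Build with `--release`; debug-mode numbers are not representative.
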

## Testing

Run the following to verify correctness:
```bash
cargo test -p owlen-core provider
cargo test -p owlen-tui
cargo test -p owlen-cli
```

All existing tests should pass without modification.

## Alternative Considered: DashMap

The report suggested `DashMap` as an alternative for lock-free concurrent reads. However, this was rejected in favor of the simpler Arc-based approach because:

1. **Simplicity**: Arc<HashMap> + RwLock is easier to understand and maintain
2. **Sufficient**: The current read/write pattern doesn't require lock-free data structures
3. **Dependency**: Avoids adding another dependency
4. **Performance**: Arc cloning is already extremely cheap (atomic increment)

If profiling shows RwLock contention in the future, DashMap can be reconsidered.
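
For reference, a minimal sketch of what the cache could look like if DashMap were adopted (hypothetical; requires the external `dashmap` crate, and the status enum below is a stand-in for the crate's real `ProviderStatus` type):

```rust
use dashmap::DashMap;

// Stand-in for the real ProviderStatus type, for a self-contained example.
#[derive(Clone, Copy, PartialEq, Debug)]
enum ProviderStatus {
    Available,
    Unavailable,
}

fn main() {
    // No outer RwLock: DashMap shards the map and locks per shard.
    let status_cache: DashMap<String, ProviderStatus> = DashMap::new();

    // Writes lock only the shard that owns the key.
    status_cache.insert("ollama".to_string(), ProviderStatus::Available);
    status_cache.insert("ollama_cloud".to_string(), ProviderStatus::Unavailable);

    // Reads return a guard borrowing the entry, without cloning the map.
    if let Some(status) = status_cache.get("ollama") {
        assert_eq!(*status, ProviderStatus::Available);
    }
}
```
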

## Implementation Status

**Partially Applied**: Due to file watcher conflicts (likely rust-analyzer or rustfmt), the changes were documented here but not all applied to the source files.

**To complete implementation**:
1. Disable file watchers temporarily
2. Apply all changes listed above
3. Run `cargo fmt` to format the code
4. Run tests to verify correctness
5. Re-enable file watchers

## References

- Project analysis report identifying clone overhead
- Rust `Arc` documentation: https://doc.rust-lang.org/std/sync/struct.Arc.html
- Copy-on-write pattern in Rust
- RwLock best practices
234
README.md
@@ -1,234 +0,0 @@

# OWLEN

> Terminal-native assistant for running local language models with a comfortable TUI.

![Rust](https://img.shields.io/badge/rust-1.75+-orange.svg)
![Version](https://img.shields.io/badge/version-0.2.0-blue.svg)
![License](https://img.shields.io/badge/license-AGPL--v3-green.svg)
![PRs](https://img.shields.io/badge/PRs-welcome-brightgreen.svg)

## What Is OWLEN?

OWLEN is a Rust-powered, terminal-first interface for interacting with local and cloud
language models. It provides a responsive chat workflow that now routes through a
multi-provider manager (handling local Ollama, Ollama Cloud, and future MCP-backed
providers), with a focus on developer productivity, vim-style navigation, and seamless
session management, all without leaving your terminal.

## Alpha Status

This project is currently in **alpha** and under active development. Core features are functional, but expect occasional bugs and breaking changes. Feedback, bug reports, and contributions are very welcome!

## Screenshots

![Owlen TUI main chat view](images/tui_main.png)

The refreshed chrome introduces a cockpit-style header with live gradient gauges for context and cloud usage, plus glassy panels that keep vim-inspired navigation easy to follow. See more screenshots in the [`images/`](images/) directory.

## Features

- **Vim-style Navigation**: Normal, editing, visual, and command modes.
- **Streaming Responses**: Real-time token streaming from Ollama.
- **Advanced Text Editing**: Multi-line input, history, and clipboard support.
- **Session Management**: Save, load, and manage conversations.
- **Code Side Panel**: Switch to code mode (`:mode code`) and open files inline with `:open <path>` for LLM-assisted coding.
- **Cockpit Header**: Gradient context and cloud usage bars with live quota bands and provider fallbacks.
- **Theming System**: 10 built-in themes and support for custom themes.
- **Modular Architecture**: Extensible provider system orchestrated by the new `ProviderManager`, ready for additional MCP-backed providers.
- **Dual-Source Model Picker**: Merge local and cloud catalogues with real-time availability badges powered by the background health worker.
- **Non-Blocking UI Loop**: Asynchronous generation tasks and provider health checks run off-thread, keeping the TUI responsive even while streaming long replies.
- **Guided Setup**: `owlen config doctor` upgrades legacy configs and verifies your environment in seconds.

## Repository Automation

Owlen now ships with Git-aware automation helpers so you can review code and stage commits without leaving the terminal:

- **CLI** – `owlen repo commit-template` renders conventional-commit scaffolding from the staged diff (`--working-tree` inspects unstaged changes), while `owlen repo review` summarises the current branch or a GitHub pull request. Provide `--owner`, `--repo`, and `--number` to fetch remote diffs; the command picks up credentials from `GITHUB_TOKEN` (override with `--token-env` or `--token`).
- **TUI** – `:repo template` injects the generated template into the conversation stream, and `:repo review [--base BRANCH] [--head REF]` produces a Markdown review of local changes. The results appear as system messages so you can follow up with an LLM turn or copy them directly into a GitHub comment.
- **Automation APIs** – Under the hood, `owlen-core::automation::repo` exposes reusable builders (`RepoAutomation`, `CommitTemplate`, `PullRequestReview`) that mirror the Claude Code workflow style. They provide JSON-serialisable checklists, workflow steps, and heuristics that highlight risky changes (e.g., new `unwrap()` calls, unchecked `unsafe` blocks, or absent tests).

Add a personal access token with `repo` scope to unlock GitHub diff fetching. Enterprise installations can point at a custom API host with the `--api-endpoint` flag.

## Upgrading to v0.2

- **Local + Cloud resiliency**: Owlen now distinguishes the on-device daemon from Ollama Cloud and gracefully falls back to local if the hosted key is missing or unauthorized. Cloud requests include `Authorization: Bearer <API_KEY>` and reuse the canonical `https://ollama.com` base URL so you no longer hit 401 loops.
- **Context + quota cockpit**: The header shows `context used / window (percentage)` and a second gauge for hourly/weekly cloud token usage. Configure soft limits via `providers.ollama_cloud.hourly_quota_tokens` and `weekly_quota_tokens`; Owlen tracks consumption locally even when the provider omits token counters.
- **Web search tooling**: When cloud is enabled, models can call the spec-compliant `web_search` tool automatically. Toggle availability at runtime with `:web on` / `:web off` if you need a local-only session.
- **Docs & config parity**: Ship-ready config templates now include per-provider `list_ttl_secs` and `default_context_window` values, plus explicit `OLLAMA_API_KEY` guidance. Run `owlen config doctor` after upgrading from v0.1 to normalize legacy keys and receive deprecation warnings for `OLLAMA_CLOUD_API_KEY` and `OWLEN_OLLAMA_CLOUD_API_KEY`.
- **Runtime toggles**: Use `:web on` / `:web off` in the TUI or `owlen providers web --enable/--disable` from the CLI to expose or hide the `web_search` tool without editing `config.toml`.

## MCP Naming & Reference Bundles

Owlen enforces spec-compliant tool identifiers: stick to `^[A-Za-z0-9_-]{1,64}$`, avoid dotted names, and keep identifiers short so the host can qualify them when multiple servers are present. Define your tools with underscores or hyphens (for example, `web_search`, `filesystem_read`, `notion_query`) and treat any legacy dotted forms as incompatible.

Modern MCP hosts converge on a common bundle of connectors that cover three broad categories: local operations (filesystem, terminal, git, structured HTTP fetch, browser automation), compute sandboxes (Python, notebook adapters, sequential-thinking planners, test runners), and SaaS integrations (GitHub issues, Notion workspaces, Slack, Stripe, Sentry, Google Drive, Zapier-style automation, design system search). Owlen’s configuration examples mirror that baseline so a fresh install can wire up the same capabilities without additional mapping.

To replicate the reference bundle today:

1. Enable the built-in tools that ship with Owlen (`web_search`, filesystem resource APIs, execution sandboxes).
2. Add external servers under `[mcp_servers]`, keeping names spec-compliant (e.g., `filesystem`, `terminal`, `git`, `browser`, `http_fetch`, `python`, `notebook`, `sequential_thinking`, `sentry`, `notion`, `slack`, `stripe`, `google_drive`, `memory_bank`, `automation_hub`).
3. Qualify tool identifiers in prompts and configs using the `{server}__{tool}` pattern once multiple servers contribute overlapping operations (`filesystem__read`, `browser__request`, `notion__query_database`); a naming sketch follows this list.
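
To make the naming rules concrete, here is a small, self-contained Rust sketch (hypothetical helper names, not part of Owlen's API) that checks an identifier against `^[A-Za-z0-9_-]{1,64}$` and applies the `{server}__{tool}` qualification:

```rust
/// Validate a tool identifier against the MCP rule `^[A-Za-z0-9_-]{1,64}$`.
/// Hypothetical helper for illustration; not part of Owlen's API.
fn is_valid_tool_name(name: &str) -> bool {
    !name.is_empty()
        // The char check below guarantees ASCII, so byte length == char count.
        && name.len() <= 64
        && name
            .chars()
            .all(|c| c.is_ascii_alphanumeric() || c == '_' || c == '-')
}

/// Qualify a tool with its server using the `{server}__{tool}` pattern.
fn qualify(server: &str, tool: &str) -> String {
    format!("{server}__{tool}")
}

fn main() {
    assert!(is_valid_tool_name("web_search"));
    assert!(!is_valid_tool_name("web.search")); // legacy dotted names are rejected
    assert_eq!(qualify("filesystem", "read"), "filesystem__read");
}
```
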
|
||||
See the updated MCP guide in `docs/` for detailed installation commands, environment variables, and health checks for each connector. The documentation set below walks through configuration and runtime toggles for `web_search` and the rest of the reference bundle.
|
||||
|
||||
## Security & Privacy
|
||||
|
||||
Owlen is designed to keep data local by default while still allowing controlled access to remote tooling.
|
||||
|
||||
- **Local-first execution**: All LLM calls flow through the bundled MCP LLM server which talks to a local Ollama instance. If the server is unreachable, Owlen stays usable in “offline mode” and surfaces clear recovery instructions.
|
||||
- **Sandboxed tooling**: Code execution runs in Docker according to the MCP Code Server settings, and future releases will extend this to other OS-level sandboxes (`sandbox-exec` on macOS, Windows job objects).
|
||||
- **Session storage**: Conversations are stored under the platform data directory and can be encrypted at rest. Set `privacy.encrypt_local_data = true` in `config.toml` to enable AES-GCM storage backed by an Owlen-managed secret key—no passphrase entry required.
|
||||
- **Network access**: No telemetry is sent. The only outbound requests occur when you explicitly enable remote tooling (e.g., web search) or configure a cloud LLM provider. Each tool is opt-in via `privacy` and `tools` configuration sections.
|
||||
- **Config migrations**: Every saved `config.toml` carries a schema version and is upgraded automatically; deprecated keys trigger warnings so security-related settings are not silently ignored.

## Getting Started

### Prerequisites

- Rust 1.75+ and Cargo.
- A running Ollama instance.
- A terminal that supports 256 colors.

### Installation

Pick the option that matches your platform and appetite for source builds:

| Platform | Package / Command | Notes |
| --- | --- | --- |
| Arch Linux | `yay -S owlen-git` | Builds from the latest `dev` branch via AUR. |
| Other Linux | `cargo install --path crates/owlen-cli --locked --force` | Requires Rust 1.75+ and a running Ollama daemon. |
| macOS | `cargo install --path crates/owlen-cli --locked --force` | macOS 12+ tested. Install Ollama separately (`brew install ollama`). The binary links against the system OpenSSL; ensure the Command Line Tools are installed. |
| Windows (experimental) | `cargo install --path crates/owlen-cli --locked --force` | Enable the GNU toolchain (`rustup target add x86_64-pc-windows-gnu`) and install Ollama for Windows preview builds. Some optional tools (e.g., Docker-based code execution) are currently disabled. |

If you prefer containerised builds, use the provided `Dockerfile` as a base image and copy out `target/release/owlen`, as sketched below.
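
A containerised build might look like this sketch; the image tag and the in-image binary path are assumptions to adapt to the actual `Dockerfile`:

```bash
docker build -t owlen-build .                    # build from the provided Dockerfile
docker create --name owlen-tmp owlen-build       # stopped container to copy from
docker cp owlen-tmp:/target/release/owlen .      # binary path inside the image is an assumption
docker rm owlen-tmp
```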

Run the helper scripts to sanity-check platform coverage:

```bash
# Windows compatibility smoke test (GNU toolchain)
scripts/check-windows.sh

# Reproduce CI packaging locally (choose a target from .woodpecker.yml)
dev/local_build.sh x86_64-unknown-linux-gnu
```

> **Tip (macOS):** On first launch, macOS Gatekeeper may quarantine the binary. Clear the attribute (`xattr -d com.apple.quarantine $(which owlen)`) or build from source locally to avoid notarisation prompts.

### Running OWLEN

Make sure Ollama is running, then launch the application:

```bash
owlen
```

If you built from source without installing, you can run it with:

```bash
./target/release/owlen
```

### Updating

Owlen does not auto-update. Run `owlen upgrade` at any time to print the recommended manual steps (pull the repository and reinstall with `cargo install --path crates/owlen-cli --force`). Arch Linux users can update via the `owlen-git` AUR package.
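
Spelled out, the manual steps are:

```bash
git pull                                        # update your checkout of the repository
cargo install --path crates/owlen-cli --force   # rebuild and reinstall the binary
```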

## Using the TUI

OWLEN uses a modal, vim-inspired interface. Press `F1` (available from any mode) or `?` in Normal mode to view the help screen with all keybindings.

- **Normal Mode**: Navigate with `h/j/k/l`, `w/b`, `gg/G`.
- **Editing Mode**: Enter with `i` or `a`. Send messages with `Enter`.
- **Command Mode**: Enter with `:`. Access commands like `:quit`, `:w`, `:session save`, `:theme`.
- **Quick Exit**: Press `Ctrl+C` twice in Normal mode to quit quickly (the first press still cancels active generations).
- **Tutorial Command**: Type `:tutorial` at any time for a quick summary of the most important keybindings.
- **MCP Slash Commands**: Owlen auto-registers zero-argument MCP tools as slash commands: type `/mcp__github__list_prs` (for example) to pull remote context directly into the chat log.

### Keymaps

Two built-in keymaps ship with Owlen:

- `vim` (default) – the existing modal bindings documented above.
- `emacs` – bindings centred around `Alt+X`, `Ctrl+Space`, and `Alt+O` shortcuts with Emacs-style submit (`Ctrl+Enter`).

Switch at runtime with `:keymap vim` or `:keymap emacs`. Persist your choice by setting `ui.keymap_profile = "emacs"` (or `"vim"`) in `config.toml`. If you prefer a fully custom layout, point `ui.keymap_path` at a TOML file using the same format as [`crates/owlen-tui/keymap.toml`](crates/owlen-tui/keymap.toml); the emacs profile file [`crates/owlen-tui/keymap_emacs.toml`](crates/owlen-tui/keymap_emacs.toml) is a useful template.
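
For example, to persist the emacs profile (both keys are described above; the custom path value is illustrative):

```toml
[ui]
keymap_profile = "emacs"
# keymap_path = "~/.config/owlen/keymap.toml"  # optional custom layout (illustrative path)
```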

Model discovery commands worth remembering:

- `:models --local` or `:models --cloud` jump directly to the corresponding section in the picker.
- `:cloud setup [--force-cloud-base-url]` stores your cloud API key without clobbering an existing local base URL (unless you opt in with the flag).
- `:limits` prints the locally tracked hourly/weekly token totals for each provider and mirrors the values shown in the chat header.

When a catalogue is unreachable, Owlen now tags the picker with `Local unavailable` / `Cloud unavailable` so you can recover without guessing.

## Documentation

For more detailed information, please refer to the following documents:

- **[CONTRIBUTING.md](CONTRIBUTING.md)**: Guidelines for contributing to the project.
- **[CHANGELOG.md](CHANGELOG.md)**: A log of changes for each version.
- **[docs/architecture.md](docs/architecture.md)**: An overview of the project's architecture.
- **[docs/troubleshooting.md](docs/troubleshooting.md)**: Help with common issues.
- **[docs/repo-map.md](docs/repo-map.md)**: Snapshot of the workspace layout and key crates.
- **[docs/provider-implementation.md](docs/provider-implementation.md)**: Trait-level details for implementing providers.
- **[docs/adding-providers.md](docs/adding-providers.md)**: Step-by-step checklist for wiring a provider into the multi-provider architecture and test suite.
- **[docs/tui-ux-playbook.md](docs/tui-ux-playbook.md)**: Design principles, modal ergonomics, and keybinding guidance for the TUI.
- **Experimental providers staging area**: [crates/providers/experimental/README.md](crates/providers/experimental/README.md) records the placeholder crates (OpenAI, Anthropic, Gemini) and their current status.
- **[docs/platform-support.md](docs/platform-support.md)**: Current OS support matrix and cross-check instructions.

## Developer Tasks

- `cargo xtask screenshots` regenerates deterministic ANSI dumps (and, when `chafa` is available, PNG renders) for the documentation gallery. Use `--no-png` to skip the PNG step or `--output <dir>` to redirect the output.
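
Typical invocations (the output directory is an illustrative value):

```bash
cargo xtask screenshots                        # ANSI dumps, plus PNGs when chafa is installed
cargo xtask screenshots --no-png               # skip the PNG step
cargo xtask screenshots --output docs/gallery  # redirect output (directory name illustrative)
```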

## Conversation Compression

Owlen automatically compacts older turns once a chat crosses the configured token threshold. The behaviour is controlled by the `[chat]` section in `config.toml` (enabled by default via `chat.auto_compress = true`).

- Launch the TUI with `--no-auto-compress` to opt out for a single run.
- Inside the app, `:compress now` generates an on-demand summary, while `:compress auto on|off` flips the automatic mode and persists the change.
- Each compression pass emits a system summary that carries metadata about the retained messages, strategy, and estimated token savings.
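
A minimal `[chat]` fragment; `auto_compress` is the documented key, while the threshold key name below is a hypothetical placeholder:

```toml
[chat]
auto_compress = true
# compression_threshold = 8000  # hypothetical name for the token threshold key
```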

## Configuration

OWLEN stores its configuration in the standard platform-specific config directory:

| Platform | Location |
|----------|----------|
| Linux | `~/.config/owlen/config.toml` |
| macOS | `~/Library/Application Support/owlen/config.toml` |
| Windows | `%APPDATA%\owlen\config.toml` |

Use `owlen config init` to scaffold a fresh configuration (pass `--force` to overwrite an existing file), `owlen config path` to print the resolved location, and `owlen config doctor` to migrate legacy layouts automatically.
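
All three subcommands at a glance:

```bash
owlen config init     # scaffold a fresh config (add --force to overwrite)
owlen config path     # print the resolved config location
owlen config doctor   # migrate legacy layouts automatically
```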

You can also add custom themes alongside the config file (e.g., in `~/.config/owlen/themes/`).

See [themes/README.md](themes/README.md) for more details on theming.

## Testing

Owlen uses standard Rust tooling for verification. Run the full test suite with:

```bash
cargo test
```

Unit tests cover the command palette state machine, agent response parsing, and key MCP abstractions. Formatting and lint checks can be run with `cargo fmt --all` and `cargo clippy` respectively.

## Roadmap

Upcoming milestones focus on feature parity with modern code assistants while keeping Owlen local-first:

1. **Phase 11 – MCP client enhancements**: `owlen mcp add/list/remove`, resource references (`@github:issue://123`), and MCP prompt slash commands.
2. **Phase 12 – Approval & sandboxing**: Three-tier approval modes plus platform-specific sandboxes (Docker, `sandbox-exec`, Windows job objects).
3. **Phase 13 – Project documentation system**: Automatic `OWLEN.md` generation, contextual updates, and nested project support.
4. **Phase 15 – Provider expansion**: OpenAI, Anthropic, and other cloud providers layered onto the existing Ollama-first architecture.

See `AGENTS.md` for the long-form roadmap and design notes.

## Contributing

Contributions are highly welcome! Please see our **[Contributing Guide](CONTRIBUTING.md)** for details on how to get started, including our code style, commit conventions, and pull request process.

## License

This project is licensed under the GNU Affero General Public License v3.0. See the [LICENSE](LICENSE) file for details.

For commercial or proprietary integrations that cannot adopt AGPL, please reach out to the maintainers to discuss alternative licensing arrangements.
40
SECURITY.md
@@ -1,40 +0,0 @@

# Security Policy

## Supported Versions

We are currently in a pre-release phase, so only the latest version is actively supported. As we move towards a 1.0 release, this policy will be updated with specific version support.

| Version | Supported |
| ------- | ------------------ |
| < 1.0 | :white_check_mark: |

## Reporting a Vulnerability

The Owlen team and community take all security vulnerabilities seriously. Thank you for improving the security of our project. We appreciate your efforts and responsible disclosure and will make every effort to acknowledge your contributions.

To report a security vulnerability, please email the project lead at [security@owlibou.com](mailto:security@owlibou.com) with a detailed description of the issue, the steps to reproduce it, and any affected versions.

You will receive a response from us within 48 hours. If the issue is confirmed, we will release a patch as soon as possible, depending on the complexity of the issue.

Please do not report security vulnerabilities through public GitHub issues.

## Design Overview

Owlen ships with a local-first architecture:

- **Process isolation** – The TUI speaks to language models through a separate MCP LLM server. Tool execution (code, web, filesystem) occurs in dedicated MCP processes, so a crash or hang cannot take down the UI.
- **Sandboxing** – The MCP Code Server executes snippets in Docker containers. Upcoming releases will extend this to platform sandboxes (`sandbox-exec` on macOS, Windows job objects) as described in our roadmap.
- **Network posture** – No telemetry is emitted. The application only reaches the network when a user explicitly enables remote tools (web search, remote MCP servers) or configures cloud providers. All tools require allow-listing in `config.toml`.

## Data Handling

- **Sessions** – Conversations are stored in the user’s data directory (`~/.local/share/owlen` on Linux, equivalent paths on macOS/Windows). Enable `privacy.encrypt_local_data = true` to wrap the session store in AES-GCM encryption using an Owlen-managed key; no interactive passphrase prompts are required.
- **Credentials** – API tokens are resolved from the config file or environment variables at runtime and are never written to logs.
- **Remote calls** – When remote search or cloud LLM tooling is on, only the minimum payload (prompt, tool arguments) is sent. All outbound requests go through the MCP servers so they can be audited or disabled centrally.

## Supply-Chain Safeguards

- The repository includes a git `pre-commit` configuration that runs `cargo fmt`, `cargo check`, and `cargo clippy -- -D warnings` on every commit (a sketch of such a hook follows this list).
- Pull requests generated with the assistance of AI tooling must receive manual maintainer review before merging. Contributors are asked to declare AI involvement in their PR description so maintainers can double-check the changes.
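
A minimal sketch of such a hook, assuming it is installed as `.git/hooks/pre-commit`; the repository's actual hook configuration may differ:

```bash
#!/usr/bin/env sh
# Abort the commit if formatting, compilation, or lints fail.
set -e
cargo fmt --all -- --check
cargo check --workspace
cargo clippy --workspace -- -D warnings
```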

Additional recommendations for operators (e.g., running Owlen on shared systems) are maintained in `docs/security.md` (planned) and the issue tracker.

@@ -1,197 +0,0 @@

# SQLx 0.7 to 0.8 Migration Guide for Owlen

## Executive Summary

The Owlen project has been successfully upgraded from SQLx 0.7 to SQLx 0.8. The migration was straightforward because Owlen uses SQLite, which is not affected by the security vulnerability RUSTSEC-2024-0363.

## Key Changes Made

### 1. Cargo.toml Update

**Before (SQLx 0.7):**

```toml
sqlx = { version = "0.7", default-features = false, features = ["runtime-tokio-rustls", "sqlite", "macros", "uuid", "chrono", "migrate"] }
```

**After (SQLx 0.8):**

```toml
sqlx = { version = "0.8", default-features = false, features = ["runtime-tokio", "tls-rustls", "sqlite", "macros", "uuid", "chrono", "migrate"] }
```

**Key change:** The combined `runtime-tokio-rustls` feature is split into `runtime-tokio` and `tls-rustls`.

## Important Notes for Owlen

### 1. Security Status

- **RUSTSEC-2024-0363 (Binary Protocol Misinterpretation)**: This vulnerability **DOES NOT AFFECT SQLite users**
  - It only affects PostgreSQL and MySQL, which use binary network protocols
  - SQLite uses an in-process C API, not a network protocol
  - There is no security risk for Owlen's SQLite implementation

### 2. Date/Time Handling

Owlen uses `chrono` types directly, not through SQLx's query macros for datetime columns. The current implementation:

- Uses `INTEGER` columns for timestamps (Unix epoch seconds)
- Converts between `SystemTime` and epoch seconds manually
- Needs no changes for datetime handling
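
An illustrative sketch of that manual conversion; the helper names are hypothetical, not Owlen's actual functions:

```rust
use std::time::{Duration, SystemTime, UNIX_EPOCH};

// Hypothetical helpers mirroring the INTEGER-epoch-seconds scheme described above.
fn to_epoch_secs(t: SystemTime) -> i64 {
    t.duration_since(UNIX_EPOCH)
        .map(|d| d.as_secs() as i64)
        .unwrap_or(0) // clamp pre-epoch times to 0
}

fn from_epoch_secs(secs: i64) -> SystemTime {
    UNIX_EPOCH + Duration::from_secs(secs.max(0) as u64)
}
```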

### 3. Database Schema

The existing migrations work without modification:

- `/crates/owlen-core/migrations/0001_create_conversations.sql`
- `/crates/owlen-core/migrations/0002_create_secure_items.sql`

### 4. Offline Mode Changes

For CI/CD pipelines:

- Offline mode is now always enabled (no separate feature flag is needed)
- Use the `SQLX_OFFLINE=true` environment variable to force offline builds
- Run `cargo sqlx prepare --workspace` to regenerate query metadata
- The `.sqlx` directory should be committed to version control

## Testing Checklist

After the upgrade, perform these tests:

- [ ] Run all unit tests: `cargo test --all`
- [ ] Test database operations:
  - [ ] Create new conversation
  - [ ] Save existing conversation
  - [ ] Load conversation by ID
  - [ ] List all conversations
  - [ ] Search conversations
  - [ ] Delete conversation
- [ ] Test migrations: `cargo sqlx migrate run`
- [ ] Test offline compilation (CI simulation):

```bash
rm -rf .sqlx
cargo sqlx prepare --workspace
SQLX_OFFLINE=true cargo build --release
```

## Migration Code Patterns

### Connection Pool Setup (No Changes Required)

The connection pool setup remains identical (imports shown for completeness):

```rust
use std::str::FromStr;
use sqlx::sqlite::{
    SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions, SqliteSynchronous,
};

let options = SqliteConnectOptions::from_str(&format!("sqlite://{}", path))?
    .create_if_missing(true)
    .journal_mode(SqliteJournalMode::Wal)
    .synchronous(SqliteSynchronous::Normal);

let pool = SqlitePoolOptions::new()
    .max_connections(5)
    .connect_with(options)
    .await?;
```

### Query Execution (No Changes Required)

Standard queries work the same:

```rust
sqlx::query(
    r#"
    INSERT INTO conversations (id, name, description, model, message_count, created_at, updated_at, data)
    VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
    ON CONFLICT(id) DO UPDATE SET
        name = excluded.name,
        description = excluded.description,
        model = excluded.model,
        message_count = excluded.message_count,
        updated_at = excluded.updated_at,
        data = excluded.data
    "#,
)
.bind(&id)
.bind(&name)
.bind(&description)
.bind(&model)
.bind(message_count)
.bind(created_at)
.bind(updated_at)
.bind(&data)
.execute(&self.pool)
.await?;
```

### Transaction Handling (No Changes Required)

```rust
let mut tx = pool.begin().await?;

sqlx::query("INSERT INTO users (name) VALUES (?)")
    .bind("Alice")
    .execute(&mut *tx)
    .await?;

tx.commit().await?;
```

## Performance Improvements in 0.8

1. **SQLite-specific fixes**: Version 0.8.6 fixed a performance regression for SQLite
2. **Better connection pooling**: More efficient connection reuse
3. **Improved compile-time checking**: Faster query validation

## Common Pitfalls to Avoid

1. **Feature flag splitting**: Don't forget to split `runtime-tokio-rustls` into the two separate features
2. **Dependency conflicts**: Check for `libsqlite3-sys` version conflicts with `cargo tree -i libsqlite3-sys`
3. **Offline mode**: Remember that offline mode is always on; there is no need to enable it separately

## Future Considerations

### If Moving to the query! Macro

If you decide to use compile-time checked queries in the future:

```rust
// Instead of manual query building
let row = sqlx::query("SELECT * FROM conversations WHERE id = ?")
    .bind(&id)
    .fetch_one(&pool)
    .await?;

// Use compile-time checked queries
let conversation = sqlx::query_as!(
    ConversationRow,
    "SELECT * FROM conversations WHERE id = ?",
    id
)
.fetch_one(&pool)
.await?;
```

### If Adding DateTime Columns

If you add proper DATETIME columns in the future (instead of INTEGER timestamps) and switch from the `chrono` feature to SQLx's preferred `time` feature:

```rust
// With SQLx 0.8's `time` feature, DATETIME columns map to `time` crate types
use time::PrimitiveDateTime;

// Instead of chrono::NaiveDateTime
#[derive(sqlx::FromRow)]
struct MyModel {
    created_at: PrimitiveDateTime, // not chrono::NaiveDateTime
}
```

## Verification Steps

1. **Build successful**: ✅ SQLx 0.8 compiles without errors
2. **Tests pass**: Run `cargo test -p owlen-core` to verify
3. **Migrations work**: Run `cargo sqlx migrate info` to check migration status
4. **Runtime works**: Start the application and perform basic operations

## Resources

- [SQLx 0.8 Release Notes](https://github.com/launchbadge/sqlx/releases/tag/v0.8.0)
- [SQLx Migration Guide](https://github.com/launchbadge/sqlx/blob/main/CHANGELOG.md)
- [RUSTSEC-2024-0363 Details](https://rustsec.org/advisories/RUSTSEC-2024-0363)
10
agents.md
@@ -1,10 +0,0 @@

# Agents Upgrade Plan

- [x] feat: support multimodal inputs (images, rich artifacts) and preview panes so non-text context matches Codex CLI image handling and Claude Code’s artifact outputs
- [x] feat: integrate repository automation (GitHub PR review, commit templating, Claude SDK-style automation APIs) to reach parity with Codex CLI’s GitHub integration and Claude Code’s CLI/SDK automation
- feat: implement Codex-style non-blocking TUI so commands remain usable while backend work runs:
  1. Add an `AppEvent` channel and dispatch layer in `crates/owlen-tui/src/app/mod.rs` that mirrors the `tokio::select!` loop used in `codex-rs/tui/src/app.rs:190-197` to multiplex UI input, session events, and background updates without blocking redraws.
  2. Refactor `ChatApp::process_pending_llm_request` and related helpers to spawn tasks that submit prompts via `SessionController` and stream results back through the new channel, following `codex-rs/tui/src/chatwidget/agent.rs:16-61` so the request lifecycle no longer stalls the UI thread.
  3. Track active-turn state plus queued inputs inside `ChatApp` and surface them through the status pane (similar to `codex-rs/tui/src/chatwidget.rs:1105-1132` and `codex-rs/tui/src/bottom_pane/mod.rs:334-352,378-383`) so users can enqueue commands/slash actions while a turn is executing.
  4. Introduce a frame requester/draw scheduler (accessible from `ChatApp` and background tasks) that coalesces redraws like `codex-rs/tui/src/tui.rs:234-390`, ensuring notifications, queue updates, and streaming deltas trigger renders without blocking the event loop.
  5. Extend input handling and regression tests to cover concurrent queued messages, cancellation, and post-turn flushing, echoing the completion hooks in `codex-rs/tui/src/chatwidget.rs:436-455` and keeping `/help` and the command palette responsive under load.
35
config.toml
@@ -1,35 +0,0 @@

[general]
default_provider = "ollama_local"
default_model = "llama3.2:latest"

[privacy]
encrypt_local_data = true

[providers.ollama_local]
enabled = true
provider_type = "ollama"
base_url = "http://localhost:11434"
list_ttl_secs = 60
default_context_window = 8192

[providers.ollama_cloud]
enabled = false
provider_type = "ollama_cloud"
base_url = "https://ollama.com"
api_key_env = "OLLAMA_API_KEY"
hourly_quota_tokens = 50000
weekly_quota_tokens = 250000
list_ttl_secs = 60
default_context_window = 8192

[providers.openai]
enabled = false
provider_type = "openai"
base_url = "https://api.openai.com/v1"
api_key_env = "OPENAI_API_KEY"

[providers.anthropic]
enabled = false
provider_type = "anthropic"
base_url = "https://api.anthropic.com/v1"
api_key_env = "ANTHROPIC_API_KEY"
@@ -1,12 +0,0 @@

[package]
name = "owlen-mcp-client"
version = "0.1.0"
edition.workspace = true
description = "Dedicated MCP client library for Owlen, exposing remote MCP server communication"
license = "AGPL-3.0"

[dependencies]
owlen-core = { path = "../../owlen-core" }

[features]
default = []
@@ -1,17 +0,0 @@

//! Owlen MCP client library.
//!
//! This crate provides a thin facade over the remote MCP client implementation
//! inside `owlen-core`. It re-exports the most useful types so downstream
//! crates can depend only on `owlen-mcp-client` without pulling in the entire
//! core crate internals.

pub use owlen_core::config::{McpConfigScope, ScopedMcpServer};
pub use owlen_core::mcp::remote_client::RemoteMcpClient;
pub use owlen_core::mcp::{McpClient, McpToolCall, McpToolDescriptor, McpToolResponse};

// Re-export the core Provider trait so that the MCP client can also be used as an LLM provider.
pub use owlen_core::Provider as McpProvider;

// Note: The `RemoteMcpClient` type provides its own `new` constructor in the core
// crate. Users can call `RemoteMcpClient::new()` directly. No additional wrapper
// is needed here.
@@ -1,22 +0,0 @@

[package]
name = "owlen-mcp-code-server"
version = "0.1.0"
edition.workspace = true
description = "MCP server exposing safe code execution tools for Owlen"
license = "AGPL-3.0"

[dependencies]
owlen-core = { path = "../../owlen-core" }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
anyhow = { workspace = true }
async-trait = { workspace = true }
bollard = "0.17"
tempfile = { workspace = true }
uuid = { workspace = true }
futures = { workspace = true }

[lib]
name = "owlen_mcp_code_server"
path = "src/lib.rs"
@@ -1,186 +0,0 @@

//! MCP server exposing code execution tools with Docker sandboxing.
//!
//! This server provides:
//! - compile_project: Build projects (Rust, Node.js, Python)
//! - run_tests: Execute test suites
//! - format_code: Run code formatters
//! - lint_code: Run linters

pub mod sandbox;
pub mod tools;

use owlen_core::mcp::protocol::{
    ErrorCode, InitializeParams, InitializeResult, PROTOCOL_VERSION, RequestId, RpcError,
    RpcErrorResponse, RpcRequest, RpcResponse, ServerCapabilities, ServerInfo, methods,
};
use owlen_core::tools::{Tool, ToolResult};
use serde_json::{Value, json};
use std::collections::HashMap;
use std::sync::Arc;
use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt};

use tools::{CompileProjectTool, FormatCodeTool, LintCodeTool, RunTestsTool};

/// Tool registry for the code server
#[allow(dead_code)]
struct ToolRegistry {
    tools: HashMap<String, Box<dyn Tool + Send + Sync>>,
}

#[allow(dead_code)]
impl ToolRegistry {
    fn new() -> Self {
        let mut tools: HashMap<String, Box<dyn Tool + Send + Sync>> = HashMap::new();
        tools.insert(
            "compile_project".to_string(),
            Box::new(CompileProjectTool::new()),
        );
        tools.insert("run_tests".to_string(), Box::new(RunTestsTool::new()));
        tools.insert("format_code".to_string(), Box::new(FormatCodeTool::new()));
        tools.insert("lint_code".to_string(), Box::new(LintCodeTool::new()));
        Self { tools }
    }

    fn list_tools(&self) -> Vec<owlen_core::mcp::McpToolDescriptor> {
        self.tools
            .values()
            .map(|tool| owlen_core::mcp::McpToolDescriptor {
                name: tool.name().to_string(),
                description: tool.description().to_string(),
                input_schema: tool.schema(),
                requires_network: tool.requires_network(),
                requires_filesystem: tool.requires_filesystem(),
            })
            .collect()
    }

    async fn execute(&self, name: &str, args: Value) -> Result<ToolResult, String> {
        self.tools
            .get(name)
            .ok_or_else(|| format!("Tool not found: {}", name))?
            .execute(args)
            .await
            .map_err(|e| e.to_string())
    }
}

#[allow(dead_code)]
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let mut stdin = io::BufReader::new(io::stdin());
    let mut stdout = io::stdout();

    let registry = Arc::new(ToolRegistry::new());

    loop {
        let mut line = String::new();
        match stdin.read_line(&mut line).await {
            Ok(0) => break, // EOF
            Ok(_) => {
                let req: RpcRequest = match serde_json::from_str(&line) {
                    Ok(r) => r,
                    Err(e) => {
                        let err = RpcErrorResponse::new(
                            RequestId::Number(0),
                            RpcError::parse_error(format!("Parse error: {}", e)),
                        );
                        let s = serde_json::to_string(&err)?;
                        stdout.write_all(s.as_bytes()).await?;
                        stdout.write_all(b"\n").await?;
                        stdout.flush().await?;
                        continue;
                    }
                };

                let resp = handle_request(req.clone(), registry.clone()).await;
                match resp {
                    Ok(r) => {
                        let s = serde_json::to_string(&r)?;
                        stdout.write_all(s.as_bytes()).await?;
                        stdout.write_all(b"\n").await?;
                        stdout.flush().await?;
                    }
                    Err(e) => {
                        let err = RpcErrorResponse::new(req.id.clone(), e);
                        let s = serde_json::to_string(&err)?;
                        stdout.write_all(s.as_bytes()).await?;
                        stdout.write_all(b"\n").await?;
                        stdout.flush().await?;
                    }
                }
            }
            Err(e) => {
                eprintln!("Error reading stdin: {}", e);
                break;
            }
        }
    }
    Ok(())
}

#[allow(dead_code)]
async fn handle_request(
    req: RpcRequest,
    registry: Arc<ToolRegistry>,
) -> Result<RpcResponse, RpcError> {
    match req.method.as_str() {
        methods::INITIALIZE => {
            let params: InitializeParams =
                serde_json::from_value(req.params.unwrap_or_else(|| json!({})))
                    .map_err(|e| RpcError::invalid_params(format!("Invalid init params: {}", e)))?;
            if !params.protocol_version.eq(PROTOCOL_VERSION) {
                return Err(RpcError::new(
                    ErrorCode::INVALID_REQUEST,
                    format!(
                        "Incompatible protocol version. Client: {}, Server: {}",
                        params.protocol_version, PROTOCOL_VERSION
                    ),
                ));
            }
            let result = InitializeResult {
                protocol_version: PROTOCOL_VERSION.to_string(),
                server_info: ServerInfo {
                    name: "owlen-mcp-code-server".to_string(),
                    version: env!("CARGO_PKG_VERSION").to_string(),
                },
                capabilities: ServerCapabilities {
                    supports_tools: Some(true),
                    supports_resources: Some(false),
                    supports_streaming: Some(false),
                },
            };
            let payload = serde_json::to_value(result).map_err(|e| {
                RpcError::internal_error(format!("Failed to serialize initialize result: {}", e))
            })?;
            Ok(RpcResponse::new(req.id, payload))
        }
        methods::TOOLS_LIST => {
            let tools = registry.list_tools();
            Ok(RpcResponse::new(req.id, json!(tools)))
        }
        methods::TOOLS_CALL => {
            let call = serde_json::from_value::<owlen_core::mcp::McpToolCall>(
                req.params.unwrap_or_else(|| json!({})),
            )
            .map_err(|e| RpcError::invalid_params(format!("Invalid tool call: {}", e)))?;

            let result: ToolResult = registry
                .execute(&call.name, call.arguments)
                .await
                .map_err(|e| RpcError::internal_error(format!("Tool execution failed: {}", e)))?;

            let resp = owlen_core::mcp::McpToolResponse {
                name: call.name,
                success: result.success,
                output: result.output,
                metadata: result.metadata,
                duration_ms: result.duration.as_millis(), // as_millis already yields u128
            };
            let payload = serde_json::to_value(resp).map_err(|e| {
                RpcError::internal_error(format!("Failed to serialize tool response: {}", e))
            })?;
            Ok(RpcResponse::new(req.id, payload))
        }
        _ => Err(RpcError::method_not_found(&req.method)),
    }
}
@@ -1,250 +0,0 @@

//! Docker-based sandboxing for secure code execution

use anyhow::{Context, Result};
use bollard::Docker;
use bollard::container::{
    Config, CreateContainerOptions, RemoveContainerOptions, StartContainerOptions,
    WaitContainerOptions,
};
use bollard::models::{HostConfig, Mount, MountTypeEnum};
use std::collections::HashMap;
use std::path::Path;

/// Result of executing code in a sandbox
#[derive(Debug, Clone)]
pub struct ExecutionResult {
    pub stdout: String,
    pub stderr: String,
    pub exit_code: i64,
    pub timed_out: bool,
}

/// Docker-based sandbox executor
pub struct Sandbox {
    docker: Docker,
    memory_limit: i64,
    cpu_quota: i64,
    timeout_secs: u64,
}

impl Sandbox {
    /// Create a new sandbox with default resource limits
    pub fn new() -> Result<Self> {
        let docker =
            Docker::connect_with_local_defaults().context("Failed to connect to Docker daemon")?;

        Ok(Self {
            docker,
            memory_limit: 512 * 1024 * 1024, // 512MB
            cpu_quota: 50000,                // 50% of one core
            timeout_secs: 30,
        })
    }

    /// Execute a command in a sandboxed container
    pub async fn execute(
        &self,
        image: &str,
        cmd: &[&str],
        workspace: Option<&Path>,
        env: HashMap<String, String>,
    ) -> Result<ExecutionResult> {
        let container_name = format!("owlen-sandbox-{}", uuid::Uuid::new_v4());

        // Prepare volume mount if workspace provided
        let mounts = if let Some(ws) = workspace {
            vec![Mount {
                target: Some("/workspace".to_string()),
                source: Some(ws.to_string_lossy().to_string()),
                typ: Some(MountTypeEnum::BIND),
                read_only: Some(false),
                ..Default::default()
            }]
        } else {
            vec![]
        };

        // Create container config
        let host_config = HostConfig {
            memory: Some(self.memory_limit),
            cpu_quota: Some(self.cpu_quota),
            network_mode: Some("none".to_string()), // No network access
            mounts: Some(mounts),
            auto_remove: Some(true),
            ..Default::default()
        };

        let config = Config {
            image: Some(image.to_string()),
            cmd: Some(cmd.iter().map(|s| s.to_string()).collect()),
            working_dir: Some("/workspace".to_string()),
            env: Some(env.iter().map(|(k, v)| format!("{}={}", k, v)).collect()),
            host_config: Some(host_config),
            attach_stdout: Some(true),
            attach_stderr: Some(true),
            tty: Some(false),
            ..Default::default()
        };

        // Create container
        let container = self
            .docker
            .create_container(
                Some(CreateContainerOptions {
                    name: container_name.clone(),
                    ..Default::default()
                }),
                config,
            )
            .await
            .context("Failed to create container")?;

        // Start container
        self.docker
            .start_container(&container.id, None::<StartContainerOptions<String>>)
            .await
            .context("Failed to start container")?;

        // Wait for container with timeout
        let wait_result =
            tokio::time::timeout(std::time::Duration::from_secs(self.timeout_secs), async {
                let mut wait_stream = self
                    .docker
                    .wait_container(&container.id, None::<WaitContainerOptions<String>>);

                use futures::StreamExt;
                if let Some(result) = wait_stream.next().await {
                    result
                } else {
                    Err(bollard::errors::Error::IOError {
                        err: std::io::Error::other("Container wait stream ended unexpectedly"),
                    })
                }
            })
            .await;

        let (exit_code, timed_out) = match wait_result {
            Ok(Ok(result)) => (result.status_code, false),
            Ok(Err(e)) => {
                eprintln!("Container wait error: {}", e);
                (1, false)
            }
            Err(_) => {
                // Timeout - kill the container
                let _ = self
                    .docker
                    .kill_container(
                        &container.id,
                        None::<bollard::container::KillContainerOptions<String>>,
                    )
                    .await;
                (124, true)
            }
        };

        // Get logs
        let logs = self.docker.logs(
            &container.id,
            Some(bollard::container::LogsOptions::<String> {
                stdout: true,
                stderr: true,
                ..Default::default()
            }),
        );

        use futures::StreamExt;
        let mut stdout = String::new();
        let mut stderr = String::new();

        let log_result = tokio::time::timeout(std::time::Duration::from_secs(5), async {
            let mut logs = logs;
            while let Some(log) = logs.next().await {
                match log {
                    Ok(bollard::container::LogOutput::StdOut { message }) => {
                        stdout.push_str(&String::from_utf8_lossy(&message));
                    }
                    Ok(bollard::container::LogOutput::StdErr { message }) => {
                        stderr.push_str(&String::from_utf8_lossy(&message));
                    }
                    _ => {}
                }
            }
        })
        .await;

        if log_result.is_err() {
            eprintln!("Timeout reading container logs");
        }

        // Remove container (auto_remove should handle this, but be explicit)
        let _ = self
            .docker
            .remove_container(
                &container.id,
                Some(RemoveContainerOptions {
                    force: true,
                    ..Default::default()
                }),
            )
            .await;

        Ok(ExecutionResult {
            stdout,
            stderr,
            exit_code,
            timed_out,
        })
    }

    /// Execute in a Rust environment
    pub async fn execute_rust(&self, workspace: &Path, cmd: &[&str]) -> Result<ExecutionResult> {
        self.execute("rust:1.75-slim", cmd, Some(workspace), HashMap::new())
            .await
    }

    /// Execute in a Python environment
    pub async fn execute_python(&self, workspace: &Path, cmd: &[&str]) -> Result<ExecutionResult> {
        self.execute("python:3.11-slim", cmd, Some(workspace), HashMap::new())
            .await
    }

    /// Execute in a Node.js environment
    pub async fn execute_node(&self, workspace: &Path, cmd: &[&str]) -> Result<ExecutionResult> {
        self.execute("node:20-slim", cmd, Some(workspace), HashMap::new())
            .await
    }
}

impl Default for Sandbox {
    fn default() -> Self {
        Self::new().expect("Failed to create default sandbox")
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[tokio::test]
    #[ignore] // Requires Docker daemon
    async fn test_sandbox_rust_compile() {
        let sandbox = Sandbox::new().unwrap();
        let temp_dir = TempDir::new().unwrap();

        // Create a simple Rust project
        std::fs::write(
            temp_dir.path().join("main.rs"),
            "fn main() { println!(\"Hello from sandbox!\"); }",
        )
        .unwrap();

        let result = sandbox
            .execute_rust(temp_dir.path(), &["rustc", "main.rs"])
            .await
            .unwrap();

        assert_eq!(result.exit_code, 0);
        assert!(!result.timed_out);
    }
}
@@ -1,417 +0,0 @@

//! Code execution tools using Docker sandboxing

use crate::sandbox::Sandbox;
use async_trait::async_trait;
use owlen_core::Result;
use owlen_core::tools::{Tool, ToolResult};
use serde_json::{Value, json};
use std::path::PathBuf;

/// Tool for compiling projects (Rust, Node.js, Python)
pub struct CompileProjectTool {
    sandbox: Sandbox,
}

impl Default for CompileProjectTool {
    fn default() -> Self {
        Self::new()
    }
}

impl CompileProjectTool {
    pub fn new() -> Self {
        Self {
            sandbox: Sandbox::default(),
        }
    }
}

#[async_trait]
impl Tool for CompileProjectTool {
    fn name(&self) -> &'static str {
        "compile_project"
    }

    fn description(&self) -> &'static str {
        "Compile a project (Rust, Node.js, Python). Detects project type automatically."
    }

    fn schema(&self) -> Value {
        json!({
            "type": "object",
            "properties": {
                "project_path": {
                    "type": "string",
                    "description": "Path to the project root"
                },
                "project_type": {
                    "type": "string",
                    "enum": ["rust", "node", "python"],
                    "description": "Project type (auto-detected if not specified)"
                }
            },
            "required": ["project_path"]
        })
    }

    async fn execute(&self, args: Value) -> Result<ToolResult> {
        let project_path = args
            .get("project_path")
            .and_then(|v| v.as_str())
            .ok_or_else(|| owlen_core::Error::InvalidInput("Missing project_path".into()))?;

        let path = PathBuf::from(project_path);
        if !path.exists() {
            return Ok(ToolResult::error("Project path does not exist"));
        }

        // Detect project type
        let project_type = if let Some(pt) = args.get("project_type").and_then(|v| v.as_str()) {
            pt.to_string()
        } else if path.join("Cargo.toml").exists() {
            "rust".to_string()
        } else if path.join("package.json").exists() {
            "node".to_string()
        } else if path.join("setup.py").exists() || path.join("pyproject.toml").exists() {
            "python".to_string()
        } else {
            return Ok(ToolResult::error("Could not detect project type"));
        };

        // Execute compilation
        let result = match project_type.as_str() {
            "rust" => self.sandbox.execute_rust(&path, &["cargo", "build"]).await,
            "node" => {
                self.sandbox
                    .execute_node(&path, &["npm", "run", "build"])
                    .await
            }
            "python" => {
                // Python typically doesn't need compilation, but we can check syntax
                self.sandbox
                    .execute_python(&path, &["python", "-m", "compileall", "."])
                    .await
            }
            _ => return Ok(ToolResult::error("Unsupported project type")),
        };

        match result {
            Ok(exec_result) => {
                if exec_result.timed_out {
                    Ok(ToolResult::error("Compilation timed out"))
                } else if exec_result.exit_code == 0 {
                    Ok(ToolResult::success(json!({
                        "success": true,
                        "stdout": exec_result.stdout,
                        "stderr": exec_result.stderr,
                        "project_type": project_type
                    })))
                } else {
                    Ok(ToolResult::success(json!({
                        "success": false,
                        "exit_code": exec_result.exit_code,
                        "stdout": exec_result.stdout,
                        "stderr": exec_result.stderr,
                        "project_type": project_type
                    })))
                }
            }
            Err(e) => Ok(ToolResult::error(&format!("Compilation failed: {}", e))),
        }
    }
}

/// Tool for running test suites
pub struct RunTestsTool {
    sandbox: Sandbox,
}

impl Default for RunTestsTool {
    fn default() -> Self {
        Self::new()
    }
}

impl RunTestsTool {
    pub fn new() -> Self {
        Self {
            sandbox: Sandbox::default(),
        }
    }
}

#[async_trait]
impl Tool for RunTestsTool {
    fn name(&self) -> &'static str {
        "run_tests"
    }

    fn description(&self) -> &'static str {
        "Run tests for a project (Rust, Node.js, Python)"
    }

    fn schema(&self) -> Value {
        json!({
            "type": "object",
            "properties": {
                "project_path": {
                    "type": "string",
                    "description": "Path to the project root"
                },
                "test_filter": {
                    "type": "string",
                    "description": "Optional test filter/pattern"
                }
            },
            "required": ["project_path"]
        })
    }

    async fn execute(&self, args: Value) -> Result<ToolResult> {
        let project_path = args
            .get("project_path")
            .and_then(|v| v.as_str())
            .ok_or_else(|| owlen_core::Error::InvalidInput("Missing project_path".into()))?;

        let path = PathBuf::from(project_path);
        if !path.exists() {
            return Ok(ToolResult::error("Project path does not exist"));
        }

        let test_filter = args.get("test_filter").and_then(|v| v.as_str());

        // Detect project type and run tests
        let result = if path.join("Cargo.toml").exists() {
            let cmd = if let Some(filter) = test_filter {
                vec!["cargo", "test", filter]
            } else {
                vec!["cargo", "test"]
            };
            self.sandbox.execute_rust(&path, &cmd).await
        } else if path.join("package.json").exists() {
            self.sandbox.execute_node(&path, &["npm", "test"]).await
        } else if path.join("pytest.ini").exists()
            || path.join("setup.py").exists()
            || path.join("pyproject.toml").exists()
        {
            let cmd = if let Some(filter) = test_filter {
                vec!["pytest", "-k", filter]
            } else {
                vec!["pytest"]
            };
            self.sandbox.execute_python(&path, &cmd).await
        } else {
            return Ok(ToolResult::error("Could not detect test framework"));
        };

        match result {
            Ok(exec_result) => Ok(ToolResult::success(json!({
                "success": exec_result.exit_code == 0 && !exec_result.timed_out,
                "exit_code": exec_result.exit_code,
                "stdout": exec_result.stdout,
                "stderr": exec_result.stderr,
                "timed_out": exec_result.timed_out
            }))),
            Err(e) => Ok(ToolResult::error(&format!("Tests failed to run: {}", e))),
        }
    }
}

/// Tool for formatting code
pub struct FormatCodeTool {
    sandbox: Sandbox,
}

impl Default for FormatCodeTool {
    fn default() -> Self {
        Self::new()
    }
}

impl FormatCodeTool {
    pub fn new() -> Self {
        Self {
            sandbox: Sandbox::default(),
        }
    }
}

#[async_trait]
impl Tool for FormatCodeTool {
    fn name(&self) -> &'static str {
        "format_code"
    }

    fn description(&self) -> &'static str {
        "Format code using project-appropriate formatter (rustfmt, prettier, black)"
    }

    fn schema(&self) -> Value {
        json!({
            "type": "object",
            "properties": {
                "project_path": {
                    "type": "string",
                    "description": "Path to the project root"
                },
                "check_only": {
                    "type": "boolean",
                    "description": "Only check formatting without modifying files",
                    "default": false
                }
            },
            "required": ["project_path"]
        })
    }

    async fn execute(&self, args: Value) -> Result<ToolResult> {
        let project_path = args
            .get("project_path")
            .and_then(|v| v.as_str())
            .ok_or_else(|| owlen_core::Error::InvalidInput("Missing project_path".into()))?;

        let path = PathBuf::from(project_path);
        if !path.exists() {
            return Ok(ToolResult::error("Project path does not exist"));
        }

        let check_only = args
            .get("check_only")
            .and_then(|v| v.as_bool())
            .unwrap_or(false);

        // Detect project type and run formatter
        let result = if path.join("Cargo.toml").exists() {
            let cmd = if check_only {
                vec!["cargo", "fmt", "--", "--check"]
            } else {
                vec!["cargo", "fmt"]
            };
            self.sandbox.execute_rust(&path, &cmd).await
        } else if path.join("package.json").exists() {
            let cmd = if check_only {
                vec!["npx", "prettier", "--check", "."]
            } else {
                vec!["npx", "prettier", "--write", "."]
            };
            self.sandbox.execute_node(&path, &cmd).await
        } else if path.join("setup.py").exists() || path.join("pyproject.toml").exists() {
            let cmd = if check_only {
                vec!["black", "--check", "."]
            } else {
                vec!["black", "."]
            };
            self.sandbox.execute_python(&path, &cmd).await
        } else {
            return Ok(ToolResult::error("Could not detect project type"));
        };

        match result {
            Ok(exec_result) => Ok(ToolResult::success(json!({
                "success": exec_result.exit_code == 0,
                "formatted": !check_only && exec_result.exit_code == 0,
                "stdout": exec_result.stdout,
                "stderr": exec_result.stderr
            }))),
            Err(e) => Ok(ToolResult::error(&format!("Formatting failed: {}", e))),
        }
    }
}

/// Tool for linting code
pub struct LintCodeTool {
    sandbox: Sandbox,
}

impl Default for LintCodeTool {
    fn default() -> Self {
        Self::new()
    }
}

impl LintCodeTool {
    pub fn new() -> Self {
        Self {
            sandbox: Sandbox::default(),
        }
    }
}

#[async_trait]
impl Tool for LintCodeTool {
    fn name(&self) -> &'static str {
        "lint_code"
    }

    fn description(&self) -> &'static str {
        "Lint code using project-appropriate linter (clippy, eslint, pylint)"
    }

    fn schema(&self) -> Value {
        json!({
            "type": "object",
            "properties": {
                "project_path": {
                    "type": "string",
                    "description": "Path to the project root"
                },
                "fix": {
                    "type": "boolean",
                    "description": "Automatically fix issues if possible",
                    "default": false
                }
            },
            "required": ["project_path"]
        })
    }

    async fn execute(&self, args: Value) -> Result<ToolResult> {
        let project_path = args
            .get("project_path")
            .and_then(|v| v.as_str())
            .ok_or_else(|| owlen_core::Error::InvalidInput("Missing project_path".into()))?;

        let path = PathBuf::from(project_path);
        if !path.exists() {
            return Ok(ToolResult::error("Project path does not exist"));
        }

        let fix = args.get("fix").and_then(|v| v.as_bool()).unwrap_or(false);

        // Detect project type and run linter
        let result = if path.join("Cargo.toml").exists() {
            let cmd = if fix {
                vec!["cargo", "clippy", "--fix", "--allow-dirty"]
            } else {
                vec!["cargo", "clippy"]
            };
            self.sandbox.execute_rust(&path, &cmd).await
        } else if path.join("package.json").exists() {
            let cmd = if fix {
                vec!["npx", "eslint", ".", "--fix"]
            } else {
                vec!["npx", "eslint", "."]
            };
            self.sandbox.execute_node(&path, &cmd).await
        } else if path.join("setup.py").exists() || path.join("pyproject.toml").exists() {
            // pylint doesn't have auto-fix
            self.sandbox.execute_python(&path, &["pylint", "."]).await
        } else {
            return Ok(ToolResult::error("Could not detect project type"));
        };

        match result {
            Ok(exec_result) => {
                let issues_found = exec_result.exit_code != 0;
                Ok(ToolResult::success(json!({
                    "success": true,
                    "issues_found": issues_found,
                    "exit_code": exec_result.exit_code,
                    "stdout": exec_result.stdout,
                    "stderr": exec_result.stderr
                })))
            }
            Err(e) => Ok(ToolResult::error(&format!("Linting failed: {}", e))),
        }
    }
}
@@ -1,16 +0,0 @@

[package]
name = "owlen-mcp-llm-server"
version = "0.1.0"
edition.workspace = true

[dependencies]
owlen-core = { path = "../../owlen-core" }
tokio = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
anyhow = { workspace = true }
tokio-stream = { workspace = true }

[[bin]]
name = "owlen-mcp-llm-server"
path = "src/main.rs"
@@ -1,598 +0,0 @@
|
||||
#![allow(
|
||||
unused_imports,
|
||||
unused_variables,
|
||||
dead_code,
|
||||
clippy::unnecessary_cast,
|
||||
clippy::manual_flatten,
|
||||
clippy::empty_line_after_outer_attr
|
||||
)]
|
||||
|
||||
use owlen_core::Provider;
|
||||
use owlen_core::ProviderConfig;
|
||||
use owlen_core::config::{Config as OwlenConfig, ensure_provider_config};
|
||||
use owlen_core::mcp::protocol::{
|
||||
ErrorCode, InitializeParams, InitializeResult, PROTOCOL_VERSION, RequestId, RpcError,
|
||||
RpcErrorResponse, RpcNotification, RpcRequest, RpcResponse, ServerCapabilities, ServerInfo,
|
||||
methods,
|
||||
};
|
||||
use owlen_core::mcp::{McpToolCall, McpToolDescriptor, McpToolResponse};
|
||||
use owlen_core::providers::OllamaProvider;
|
||||
use owlen_core::types::{ChatParameters, ChatRequest, Message};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{Value, json};
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt};
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
// Suppress warnings are handled by the crate-level attribute at the top.
|
||||
|
||||
/// Arguments for the generate_text tool
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GenerateTextArgs {
|
||||
messages: Vec<Message>,
|
||||
temperature: Option<f32>,
|
||||
max_tokens: Option<u32>,
|
||||
model: String,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
/// Simple tool descriptor for generate_text
|
||||
fn generate_text_descriptor() -> McpToolDescriptor {
|
||||
McpToolDescriptor {
|
||||
name: "generate_text".to_string(),
|
||||
description: "Generate text using Ollama LLM. Each message must have 'role' (user/assistant/system) and 'content' (string) fields.".to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"messages": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"role": {
|
||||
"type": "string",
|
||||
"enum": ["user", "assistant", "system"],
|
||||
"description": "The role of the message sender"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The message content"
|
||||
}
|
||||
},
|
||||
"required": ["role", "content"]
|
||||
},
|
||||
"description": "Array of message objects with role and content"
|
||||
},
|
||||
"temperature": {"type": ["number", "null"], "description": "Sampling temperature (0.0-2.0)"},
|
||||
"max_tokens": {"type": ["integer", "null"], "description": "Maximum tokens to generate"},
|
||||
"model": {"type": "string", "description": "Model name (e.g., llama3.2:latest)"},
|
||||
"stream": {"type": "boolean", "description": "Whether to stream the response"}
|
||||
},
|
||||
"required": ["messages", "model", "stream"]
|
||||
}),
|
||||
requires_network: true,
|
||||
requires_filesystem: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
/// Tool descriptor for resources/get (read file)
|
||||
fn resources_get_descriptor() -> McpToolDescriptor {
|
||||
McpToolDescriptor {
|
||||
name: "resources_get".to_string(),
|
||||
description: "Read and return the TEXT CONTENTS of a single FILE. Use this to read the contents of code files, config files, or text documents. Do NOT use for directories.".to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Path to the FILE (not directory) to read"}
|
||||
},
|
||||
"required": ["path"]
|
||||
}),
|
||||
requires_network: false,
|
||||
requires_filesystem: vec!["read".to_string()],
|
||||
}
|
||||
}
|
||||
|
||||
/// Tool descriptor for resources/list (list directory)
|
||||
fn resources_list_descriptor() -> McpToolDescriptor {
|
||||
McpToolDescriptor {
|
||||
name: "resources_list".to_string(),
|
||||
description: "List the NAMES of all files and directories in a directory. Use this to see what files exist in a folder, or to list directory contents. Returns an array of file/directory names.".to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Path to the DIRECTORY to list (use '.' for current directory)"}
|
||||
}
|
||||
}),
|
||||
requires_network: false,
|
||||
requires_filesystem: vec!["read".to_string()],
|
||||
}
|
||||
}
|
||||
|
||||
fn provider_from_config() -> Result<Arc<dyn Provider>, RpcError> {
|
||||
let mut config = OwlenConfig::load(None).unwrap_or_default();
|
||||
let requested_name =
|
||||
env::var("OWLEN_PROVIDER").unwrap_or_else(|_| config.general.default_provider.clone());
|
||||
let provider_key = canonical_provider_name(&requested_name);
|
||||
if config.provider(&provider_key).is_none() {
|
||||
ensure_provider_config(&mut config, &provider_key);
|
||||
}
|
||||
let provider_cfg: ProviderConfig =
|
||||
config.provider(&provider_key).cloned().ok_or_else(|| {
|
||||
RpcError::internal_error(format!(
|
||||
"Provider '{provider_key}' not found in configuration"
|
||||
))
|
||||
})?;
|
||||
|
||||
match provider_cfg.provider_type.as_str() {
|
||||
"ollama" | "ollama_cloud" => {
|
||||
let provider =
|
||||
OllamaProvider::from_config(&provider_key, &provider_cfg, Some(&config.general))
|
||||
.map_err(|e| {
|
||||
RpcError::internal_error(format!(
|
||||
"Failed to init Ollama provider from config: {e}"
|
||||
))
|
||||
})?;
|
||||
Ok(Arc::new(provider) as Arc<dyn Provider>)
|
||||
}
|
||||
other => Err(RpcError::internal_error(format!(
|
||||
"Unsupported provider type '{other}' for MCP LLM server"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn create_provider() -> Result<Arc<dyn Provider>, RpcError> {
    if let Ok(url) = env::var("OLLAMA_URL") {
        let provider = OllamaProvider::new(&url).map_err(|e| {
            RpcError::internal_error(format!("Failed to init Ollama provider: {e}"))
        })?;
        return Ok(Arc::new(provider) as Arc<dyn Provider>);
    }

    provider_from_config()
}

fn canonical_provider_name(name: &str) -> String {
    let normalized = name.trim().to_ascii_lowercase().replace('-', "_");
    match normalized.as_str() {
        "" => "ollama_local".to_string(),
        "ollama" | "ollama_local" => "ollama_local".to_string(),
        "ollama_cloud" => "ollama_cloud".to_string(),
        other => other.to_string(),
    }
}

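The normalization above folds case and hyphens before matching, so the two built-in provider keys catch common alternate spellings. A minimal test sketch of its behavior (test values invented):

#[cfg(test)]
mod canonical_name_tests {
    use super::canonical_provider_name;

    #[test]
    fn normalizes_provider_names() {
        // Hyphens become underscores and casing is ignored.
        assert_eq!(canonical_provider_name("Ollama-Cloud"), "ollama_cloud");
        // Bare "ollama" and the empty string both map to the local key.
        assert_eq!(canonical_provider_name("ollama"), "ollama_local");
        assert_eq!(canonical_provider_name(""), "ollama_local");
        // Unknown names pass through unchanged.
        assert_eq!(canonical_provider_name("openai"), "openai");
    }
}
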
async fn handle_generate_text(args: GenerateTextArgs) -> Result<String, RpcError> {
    let provider = create_provider()?;

    let parameters = ChatParameters {
        temperature: args.temperature,
        max_tokens: args.max_tokens.map(|v| v as u32),
        stream: args.stream,
        extra: HashMap::new(),
    };

    let request = ChatRequest {
        model: args.model,
        messages: args.messages,
        parameters,
        tools: None,
    };

    // Use streaming API and collect output
    let mut stream = provider
        .stream_prompt(request)
        .await
        .map_err(|e| RpcError::internal_error(format!("Chat request failed: {}", e)))?;
    let mut content = String::new();
    while let Some(chunk) = stream.next().await {
        match chunk {
            Ok(resp) => {
                content.push_str(&resp.message.content);
                if resp.is_final {
                    break;
                }
            }
            Err(e) => {
                return Err(RpcError::internal_error(format!("Stream error: {}", e)));
            }
        }
    }
    Ok(content)
}

async fn handle_request(req: &RpcRequest) -> Result<Value, RpcError> {
    match req.method.as_str() {
        methods::INITIALIZE => {
            let params = req
                .params
                .as_ref()
                .ok_or_else(|| RpcError::invalid_params("Missing params for initialize"))?;
            let init: InitializeParams = serde_json::from_value(params.clone())
                .map_err(|e| RpcError::invalid_params(format!("Invalid init params: {}", e)))?;
            if !init.protocol_version.eq(PROTOCOL_VERSION) {
                return Err(RpcError::new(
                    ErrorCode::INVALID_REQUEST,
                    format!(
                        "Incompatible protocol version. Client: {}, Server: {}",
                        init.protocol_version, PROTOCOL_VERSION
                    ),
                ));
            }
            let result = InitializeResult {
                protocol_version: PROTOCOL_VERSION.to_string(),
                server_info: ServerInfo {
                    name: "owlen-mcp-llm-server".to_string(),
                    version: env!("CARGO_PKG_VERSION").to_string(),
                },
                capabilities: ServerCapabilities {
                    supports_tools: Some(true),
                    supports_resources: Some(false),
                    supports_streaming: Some(true),
                },
            };
            serde_json::to_value(result).map_err(|e| {
                RpcError::internal_error(format!("Failed to serialize init result: {}", e))
            })
        }
        methods::TOOLS_LIST => {
            let tools = vec![
                generate_text_descriptor(),
                resources_get_descriptor(),
                resources_list_descriptor(),
            ];
            Ok(json!(tools))
        }
        // New method to list available Ollama models via the provider.
        methods::MODELS_LIST => {
            let provider = create_provider()?;
            let models = provider
                .list_models()
                .await
                .map_err(|e| RpcError::internal_error(format!("Failed to list models: {}", e)))?;
            serde_json::to_value(models).map_err(|e| {
                RpcError::internal_error(format!("Failed to serialize model list: {}", e))
            })
        }
        methods::TOOLS_CALL => {
            // For streaming we will send incremental notifications directly from here.
            // The caller (main loop) will handle writing the final response.
            Err(RpcError::internal_error(
                "TOOLS_CALL should be handled in main loop for streaming",
            ))
        }
        _ => Err(RpcError::method_not_found(&req.method)),
    }
}

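Note that this server required an exact protocol-version match (`init.protocol_version.eq(PROTOCOL_VERSION)`), whereas the file-system server later in this commit accepts anything `is_compatible`. On the wire, each exchange was one JSON object per line; illustratively (envelope fields assumed to follow JSON-RPC 2.0 as modeled by `RpcRequest`/`RpcResponse` in owlen-core, id and version values invented):

{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocol_version":"1.0"}}
{"jsonrpc":"2.0","id":1,"result":{"protocol_version":"1.0","server_info":{"name":"owlen-mcp-llm-server","version":"0.1.0"},...}}
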
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let _root = env::current_dir()?; // not used, but kept for parity with the other servers
    let mut stdin = io::BufReader::new(io::stdin());
    let mut stdout = io::stdout();
    loop {
        let mut line = String::new();
        match stdin.read_line(&mut line).await {
            Ok(0) => break,
            Ok(_) => {
                let req: RpcRequest = match serde_json::from_str(&line) {
                    Ok(r) => r,
                    Err(e) => {
                        let err = RpcErrorResponse::new(
                            RequestId::Number(0),
                            RpcError::parse_error(format!("Parse error: {}", e)),
                        );
                        let s = serde_json::to_string(&err)?;
                        stdout.write_all(s.as_bytes()).await?;
                        stdout.write_all(b"\n").await?;
                        stdout.flush().await?;
                        continue;
                    }
                };
                let id = req.id.clone();
                // Streaming tool calls (generate_text) are handled specially to emit incremental notifications.
                if req.method == methods::TOOLS_CALL {
                    // Parse the tool call
                    let params = match &req.params {
                        Some(p) => p,
                        None => {
                            let err_resp = RpcErrorResponse::new(
                                id.clone(),
                                RpcError::invalid_params("Missing params for tool call"),
                            );
                            let s = serde_json::to_string(&err_resp)?;
                            stdout.write_all(s.as_bytes()).await?;
                            stdout.write_all(b"\n").await?;
                            stdout.flush().await?;
                            continue;
                        }
                    };
                    let call: McpToolCall = match serde_json::from_value(params.clone()) {
                        Ok(c) => c,
                        Err(e) => {
                            let err_resp = RpcErrorResponse::new(
                                id.clone(),
                                RpcError::invalid_params(format!("Invalid tool call: {}", e)),
                            );
                            let s = serde_json::to_string(&err_resp)?;
                            stdout.write_all(s.as_bytes()).await?;
                            stdout.write_all(b"\n").await?;
                            stdout.flush().await?;
                            continue;
                        }
                    };
                    // Dispatch based on the requested tool name.
                    // Handle resources tools manually.
                    if call.name.starts_with("resources_get") {
                        let path = call
                            .arguments
                            .get("path")
                            .and_then(|v| v.as_str())
                            .unwrap_or("");
                        match std::fs::read_to_string(path) {
                            Ok(content) => {
                                let response = McpToolResponse {
                                    name: call.name,
                                    success: true,
                                    output: json!(content),
                                    metadata: HashMap::new(),
                                    duration_ms: 0,
                                };
                                let payload = match serde_json::to_value(&response) {
                                    Ok(value) => value,
                                    Err(e) => {
                                        let err_resp = RpcErrorResponse::new(
                                            id.clone(),
                                            RpcError::internal_error(format!(
                                                "Failed to serialize resource response: {}",
                                                e
                                            )),
                                        );
                                        let s = serde_json::to_string(&err_resp)?;
                                        stdout.write_all(s.as_bytes()).await?;
                                        stdout.write_all(b"\n").await?;
                                        stdout.flush().await?;
                                        continue;
                                    }
                                };
                                let final_resp = RpcResponse::new(id.clone(), payload);
                                let s = serde_json::to_string(&final_resp)?;
                                stdout.write_all(s.as_bytes()).await?;
                                stdout.write_all(b"\n").await?;
                                stdout.flush().await?;
                                continue;
                            }
                            Err(e) => {
                                let err_resp = RpcErrorResponse::new(
                                    id.clone(),
                                    RpcError::internal_error(format!("Failed to read file: {}", e)),
                                );
                                let s = serde_json::to_string(&err_resp)?;
                                stdout.write_all(s.as_bytes()).await?;
                                stdout.write_all(b"\n").await?;
                                stdout.flush().await?;
                                continue;
                            }
                        }
                    }
                    if call.name.starts_with("resources_list") {
                        let path = call
                            .arguments
                            .get("path")
                            .and_then(|v| v.as_str())
                            .unwrap_or(".");
                        match std::fs::read_dir(path) {
                            Ok(entries) => {
                                let mut names = Vec::new();
                                for entry in entries.flatten() {
                                    if let Some(name) = entry.file_name().to_str() {
                                        names.push(name.to_string());
                                    }
                                }
                                let response = McpToolResponse {
                                    name: call.name,
                                    success: true,
                                    output: json!(names),
                                    metadata: HashMap::new(),
                                    duration_ms: 0,
                                };
                                let payload = match serde_json::to_value(&response) {
                                    Ok(value) => value,
                                    Err(e) => {
                                        let err_resp = RpcErrorResponse::new(
                                            id.clone(),
                                            RpcError::internal_error(format!(
                                                "Failed to serialize directory listing: {}",
                                                e
                                            )),
                                        );
                                        let s = serde_json::to_string(&err_resp)?;
                                        stdout.write_all(s.as_bytes()).await?;
                                        stdout.write_all(b"\n").await?;
                                        stdout.flush().await?;
                                        continue;
                                    }
                                };
                                let final_resp = RpcResponse::new(id.clone(), payload);
                                let s = serde_json::to_string(&final_resp)?;
                                stdout.write_all(s.as_bytes()).await?;
                                stdout.write_all(b"\n").await?;
                                stdout.flush().await?;
                                continue;
                            }
                            Err(e) => {
                                let err_resp = RpcErrorResponse::new(
                                    id.clone(),
                                    RpcError::internal_error(format!("Failed to list dir: {}", e)),
                                );
                                let s = serde_json::to_string(&err_resp)?;
                                stdout.write_all(s.as_bytes()).await?;
                                stdout.write_all(b"\n").await?;
                                stdout.flush().await?;
                                continue;
                            }
                        }
                    }
                    // Expect generate_text tool for the remaining path.
                    if call.name != "generate_text" {
                        let err_resp =
                            RpcErrorResponse::new(id.clone(), RpcError::tool_not_found(&call.name));
                        let s = serde_json::to_string(&err_resp)?;
                        stdout.write_all(s.as_bytes()).await?;
                        stdout.write_all(b"\n").await?;
                        stdout.flush().await?;
                        continue;
                    }
                    let args: GenerateTextArgs =
                        match serde_json::from_value(call.arguments.clone()) {
                            Ok(a) => a,
                            Err(e) => {
                                let err_resp = RpcErrorResponse::new(
                                    id.clone(),
                                    RpcError::invalid_params(format!("Invalid arguments: {}", e)),
                                );
                                let s = serde_json::to_string(&err_resp)?;
                                stdout.write_all(s.as_bytes()).await?;
                                stdout.write_all(b"\n").await?;
                                stdout.flush().await?;
                                continue;
                            }
                        };

                    // Initialize provider and start streaming
                    let provider = match create_provider() {
                        Ok(p) => p,
                        Err(e) => {
                            let err_resp = RpcErrorResponse::new(
                                id.clone(),
                                RpcError::internal_error(format!(
                                    "Failed to initialize provider: {:?}",
                                    e
                                )),
                            );
                            let s = serde_json::to_string(&err_resp)?;
                            stdout.write_all(s.as_bytes()).await?;
                            stdout.write_all(b"\n").await?;
                            stdout.flush().await?;
                            continue;
                        }
                    };
                    let parameters = ChatParameters {
                        temperature: args.temperature,
                        max_tokens: args.max_tokens.map(|v| v as u32),
                        stream: true,
                        extra: HashMap::new(),
                    };
                    let request = ChatRequest {
                        model: args.model,
                        messages: args.messages,
                        parameters,
                        tools: None,
                    };
                    let mut stream = match provider.stream_prompt(request).await {
                        Ok(s) => s,
                        Err(e) => {
                            let err_resp = RpcErrorResponse::new(
                                id.clone(),
                                RpcError::internal_error(format!("Chat request failed: {}", e)),
                            );
                            let s = serde_json::to_string(&err_resp)?;
                            stdout.write_all(s.as_bytes()).await?;
                            stdout.write_all(b"\n").await?;
                            stdout.flush().await?;
                            continue;
                        }
                    };
                    // Accumulate full content while sending incremental progress notifications
                    let mut final_content = String::new();
                    while let Some(chunk) = stream.next().await {
                        match chunk {
                            Ok(resp) => {
                                // Append chunk to the final content buffer
                                final_content.push_str(&resp.message.content);
                                // Emit a progress notification for the UI
                                let notif = RpcNotification::new(
                                    "tools/call/progress",
                                    Some(json!({ "content": resp.message.content })),
                                );
                                let s = serde_json::to_string(&notif)?;
                                stdout.write_all(s.as_bytes()).await?;
                                stdout.write_all(b"\n").await?;
                                stdout.flush().await?;
                                if resp.is_final {
                                    break;
                                }
                            }
                            Err(e) => {
                                let err_resp = RpcErrorResponse::new(
                                    id.clone(),
                                    RpcError::internal_error(format!("Stream error: {}", e)),
                                );
                                let s = serde_json::to_string(&err_resp)?;
                                stdout.write_all(s.as_bytes()).await?;
                                stdout.write_all(b"\n").await?;
                                stdout.flush().await?;
                                break;
                            }
                        }
                    }
                    // After streaming, send the final tool response containing the full content
                    let final_output = final_content.clone();
                    let response = McpToolResponse {
                        name: call.name,
                        success: true,
                        output: json!(final_output),
                        metadata: HashMap::new(),
                        duration_ms: 0,
                    };
                    let payload = match serde_json::to_value(&response) {
                        Ok(value) => value,
                        Err(e) => {
                            let err_resp = RpcErrorResponse::new(
                                id.clone(),
                                RpcError::internal_error(format!(
                                    "Failed to serialize final streaming response: {}",
                                    e
                                )),
                            );
                            let s = serde_json::to_string(&err_resp)?;
                            stdout.write_all(s.as_bytes()).await?;
                            stdout.write_all(b"\n").await?;
                            stdout.flush().await?;
                            continue;
                        }
                    };
                    let final_resp = RpcResponse::new(id.clone(), payload);
                    let s = serde_json::to_string(&final_resp)?;
                    stdout.write_all(s.as_bytes()).await?;
                    stdout.write_all(b"\n").await?;
                    stdout.flush().await?;
                    continue;
                }
                // Non-streaming requests are handled by the generic handler
                match handle_request(&req).await {
                    Ok(res) => {
                        let resp = RpcResponse::new(id, res);
                        let s = serde_json::to_string(&resp)?;
                        stdout.write_all(s.as_bytes()).await?;
                        stdout.write_all(b"\n").await?;
                        stdout.flush().await?;
                    }
                    Err(err) => {
                        let err_resp = RpcErrorResponse::new(id, err);
                        let s = serde_json::to_string(&err_resp)?;
                        stdout.write_all(s.as_bytes()).await?;
                        stdout.write_all(b"\n").await?;
                        stdout.flush().await?;
                    }
                }
            }
            Err(e) => {
                eprintln!("Read error: {}", e);
                break;
            }
        }
    }
    Ok(())
}
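During a generate_text call, the loop above interleaved progress notifications with the final response on stdout, one JSON object per line. Illustratively (assuming `RpcNotification` serializes to standard JSON-RPC; id and chunk text invented):

{"jsonrpc":"2.0","method":"tools/call/progress","params":{"content":"Hel"}}
{"jsonrpc":"2.0","method":"tools/call/progress","params":{"content":"lo"}}
{"jsonrpc":"2.0","id":42,"result":{"name":"generate_text","success":true,"output":"Hello","metadata":{},"duration_ms":0}}
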
@@ -1,21 +0,0 @@
[package]
name = "owlen-mcp-prompt-server"
version = "0.1.0"
edition.workspace = true
description = "MCP server that renders prompt templates (YAML) for Owlen"
license = "AGPL-3.0"

[dependencies]
owlen-core = { path = "../../owlen-core" }
serde = { workspace = true }
serde_json = { workspace = true }
serde_yaml = { workspace = true }
tokio = { workspace = true }
anyhow = { workspace = true }
handlebars = { workspace = true }
dirs = { workspace = true }
futures = { workspace = true }

[lib]
name = "owlen_mcp_prompt_server"
path = "src/lib.rs"
@@ -1,415 +0,0 @@
//! MCP server for rendering prompt templates with YAML storage and Handlebars rendering.
//!
//! Templates are stored in `~/.config/owlen/prompts/` as YAML files.
//! Provides full Handlebars templating support for dynamic prompt generation.

use anyhow::{Context, Result};
use handlebars::Handlebars;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::RwLock;

use owlen_core::mcp::protocol::{
    ErrorCode, InitializeParams, InitializeResult, PROTOCOL_VERSION, RequestId, RpcError,
    RpcErrorResponse, RpcRequest, RpcResponse, ServerCapabilities, ServerInfo, methods,
};
use owlen_core::mcp::{McpToolCall, McpToolDescriptor, McpToolResponse};
use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt};

/// Prompt template definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptTemplate {
    /// Template name
    pub name: String,
    /// Template version
    pub version: String,
    /// Optional mode restriction
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mode: Option<String>,
    /// Handlebars template content
    pub template: String,
    /// Template description
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}

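Because `PromptTemplate` derives `Serialize`/`Deserialize` and files are read with `serde_yaml` below, a template on disk is simply the struct's fields as YAML. A minimal sketch of the expected file shape (values invented; `mode` and `description` may be omitted thanks to the `skip_serializing_if` attributes):

name: greeting
version: "1.0"
mode: chat
description: Minimal example template
template: |
  Hello {{user}}!
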
/// Prompt server managing templates
pub struct PromptServer {
    templates: Arc<RwLock<HashMap<String, PromptTemplate>>>,
    handlebars: Handlebars<'static>,
    templates_dir: PathBuf,
}

impl PromptServer {
    /// Create a new prompt server
    pub fn new() -> Result<Self> {
        let templates_dir = Self::get_templates_dir()?;

        // Create templates directory if it doesn't exist
        if !templates_dir.exists() {
            fs::create_dir_all(&templates_dir)?;
            Self::create_default_templates(&templates_dir)?;
        }

        let mut server = Self {
            templates: Arc::new(RwLock::new(HashMap::new())),
            handlebars: Handlebars::new(),
            templates_dir,
        };

        // Load all templates
        server.load_templates()?;

        Ok(server)
    }

    /// Get the templates directory path
    fn get_templates_dir() -> Result<PathBuf> {
        let config_dir = dirs::config_dir().context("Could not determine config directory")?;
        Ok(config_dir.join("owlen").join("prompts"))
    }

    /// Create default template examples
    fn create_default_templates(dir: &Path) -> Result<()> {
        let chat_mode_system = PromptTemplate {
            name: "chat_mode_system".to_string(),
            version: "1.0".to_string(),
            mode: Some("chat".to_string()),
            description: Some("System prompt for chat mode".to_string()),
            template: r#"You are Owlen, a helpful AI assistant. You have access to these tools:
{{#each tools}}
- {{name}}: {{description}}
{{/each}}

Use the ReAct pattern:
THOUGHT: Your reasoning
ACTION: tool_name
ACTION_INPUT: {"param": "value"}

When you have enough information:
FINAL_ANSWER: Your response"#
                .to_string(),
        };

        let code_mode_system = PromptTemplate {
            name: "code_mode_system".to_string(),
            version: "1.0".to_string(),
            mode: Some("code".to_string()),
            description: Some("System prompt for code mode".to_string()),
            template: r#"You are Owlen in code mode, with full development capabilities. You have access to:
{{#each tools}}
- {{name}}: {{description}}
{{/each}}

Use the ReAct pattern to solve coding tasks:
THOUGHT: Analyze what needs to be done
ACTION: tool_name (compile_project, run_tests, format_code, lint_code, etc.)
ACTION_INPUT: {"param": "value"}

Continue iterating until the task is complete, then provide:
FINAL_ANSWER: Summary of what was done"#
                .to_string(),
        };

        // Save templates
        let chat_path = dir.join("chat_mode_system.yaml");
        let code_path = dir.join("code_mode_system.yaml");

        fs::write(chat_path, serde_yaml::to_string(&chat_mode_system)?)?;
        fs::write(code_path, serde_yaml::to_string(&code_mode_system)?)?;

        Ok(())
    }

    /// Load all templates from the templates directory
    fn load_templates(&mut self) -> Result<()> {
        let entries = fs::read_dir(&self.templates_dir)?;

        for entry in entries {
            let entry = entry?;
            let path = entry.path();

            if path.extension().and_then(|s| s.to_str()) == Some("yaml")
                || path.extension().and_then(|s| s.to_str()) == Some("yml")
            {
                match self.load_template(&path) {
                    Ok(template) => {
                        // Register with Handlebars
                        if let Err(e) = self
                            .handlebars
                            .register_template_string(&template.name, &template.template)
                        {
                            eprintln!(
                                "Warning: Failed to register template {}: {}",
                                template.name, e
                            );
                        } else {
                            // Avoid `blocking_write` here: it panics when called
                            // from inside the Tokio runtime. The lock is
                            // uncontended while templates are being (re)loaded,
                            // so `try_write` always succeeds.
                            let mut templates = self
                                .templates
                                .try_write()
                                .expect("templates lock is uncontended during load");
                            templates.insert(template.name.clone(), template);
                        }
                    }
                    Err(e) => {
                        eprintln!("Warning: Failed to load template {:?}: {}", path, e);
                    }
                }
            }
        }

        Ok(())
    }

    /// Load a single template from file
    fn load_template(&self, path: &Path) -> Result<PromptTemplate> {
        let content = fs::read_to_string(path)?;
        let template: PromptTemplate = serde_yaml::from_str(&content)?;
        Ok(template)
    }

    /// Get a template by name
    pub async fn get_template(&self, name: &str) -> Option<PromptTemplate> {
        let templates = self.templates.read().await;
        templates.get(name).cloned()
    }

    /// List all available templates
    pub async fn list_templates(&self) -> Vec<String> {
        let templates = self.templates.read().await;
        templates.keys().cloned().collect()
    }

    /// Render a template with given variables
    pub fn render_template(&self, name: &str, vars: &Value) -> Result<String> {
        self.handlebars
            .render(name, vars)
            .context("Failed to render template")
    }

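A minimal usage sketch for the lookup and render paths above, using the default chat template created earlier (variable payload invented; the template iterates over a `tools` array):

let server = PromptServer::new()?;
let vars = json!({
    "tools": [
        { "name": "resources_get", "description": "Read a file" }
    ]
});
let rendered = server.render_template("chat_mode_system", &vars)?;
println!("{rendered}");
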
    /// Reload all templates from disk
    pub async fn reload_templates(&mut self) -> Result<()> {
        {
            let mut templates = self.templates.write().await;
            templates.clear();
        }
        self.handlebars = Handlebars::new();
        self.load_templates()
    }
}

#[allow(dead_code)]
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let mut stdin = io::BufReader::new(io::stdin());
    let mut stdout = io::stdout();

    let server = Arc::new(tokio::sync::Mutex::new(PromptServer::new()?));

    loop {
        let mut line = String::new();
        match stdin.read_line(&mut line).await {
            Ok(0) => break, // EOF
            Ok(_) => {
                let req: RpcRequest = match serde_json::from_str(&line) {
                    Ok(r) => r,
                    Err(e) => {
                        let err = RpcErrorResponse::new(
                            RequestId::Number(0),
                            RpcError::parse_error(format!("Parse error: {}", e)),
                        );
                        let s = serde_json::to_string(&err)?;
                        stdout.write_all(s.as_bytes()).await?;
                        stdout.write_all(b"\n").await?;
                        stdout.flush().await?;
                        continue;
                    }
                };

                let resp = handle_request(req.clone(), server.clone()).await;
                match resp {
                    Ok(r) => {
                        let s = serde_json::to_string(&r)?;
                        stdout.write_all(s.as_bytes()).await?;
                        stdout.write_all(b"\n").await?;
                        stdout.flush().await?;
                    }
                    Err(e) => {
                        let err = RpcErrorResponse::new(req.id.clone(), e);
                        let s = serde_json::to_string(&err)?;
                        stdout.write_all(s.as_bytes()).await?;
                        stdout.write_all(b"\n").await?;
                        stdout.flush().await?;
                    }
                }
            }
            Err(e) => {
                eprintln!("Error reading stdin: {}", e);
                break;
            }
        }
    }
    Ok(())
}

#[allow(dead_code)]
async fn handle_request(
    req: RpcRequest,
    server: Arc<tokio::sync::Mutex<PromptServer>>,
) -> Result<RpcResponse, RpcError> {
    match req.method.as_str() {
        methods::INITIALIZE => {
            let params: InitializeParams =
                serde_json::from_value(req.params.unwrap_or_else(|| json!({})))
                    .map_err(|e| RpcError::invalid_params(format!("Invalid init params: {}", e)))?;
            if !params.protocol_version.eq(PROTOCOL_VERSION) {
                return Err(RpcError::new(
                    ErrorCode::INVALID_REQUEST,
                    format!(
                        "Incompatible protocol version. Client: {}, Server: {}",
                        params.protocol_version, PROTOCOL_VERSION
                    ),
                ));
            }
            let result = InitializeResult {
                protocol_version: PROTOCOL_VERSION.to_string(),
                server_info: ServerInfo {
                    name: "owlen-mcp-prompt-server".to_string(),
                    version: env!("CARGO_PKG_VERSION").to_string(),
                },
                capabilities: ServerCapabilities {
                    supports_tools: Some(true),
                    supports_resources: Some(false),
                    supports_streaming: Some(false),
                },
            };
            let payload = serde_json::to_value(result).map_err(|e| {
                RpcError::internal_error(format!("Failed to serialize initialize result: {}", e))
            })?;
            Ok(RpcResponse::new(req.id, payload))
        }
        methods::TOOLS_LIST => {
            let tools = vec![
                McpToolDescriptor {
                    name: "get_prompt".to_string(),
                    description: "Retrieve a prompt template by name".to_string(),
                    input_schema: json!({
                        "type": "object",
                        "properties": {
                            "name": {"type": "string", "description": "Template name"}
                        },
                        "required": ["name"]
                    }),
                    requires_network: false,
                    requires_filesystem: vec![],
                },
                McpToolDescriptor {
                    name: "render_prompt".to_string(),
                    description: "Render a prompt template with Handlebars variables".to_string(),
                    input_schema: json!({
                        "type": "object",
                        "properties": {
                            "name": {"type": "string", "description": "Template name"},
                            "vars": {"type": "object", "description": "Variables for Handlebars rendering"}
                        },
                        "required": ["name"]
                    }),
                    requires_network: false,
                    requires_filesystem: vec![],
                },
                McpToolDescriptor {
                    name: "list_prompts".to_string(),
                    description: "List all available prompt templates".to_string(),
                    input_schema: json!({"type": "object", "properties": {}}),
                    requires_network: false,
                    requires_filesystem: vec![],
                },
                McpToolDescriptor {
                    name: "reload_prompts".to_string(),
                    description: "Reload all prompts from disk".to_string(),
                    input_schema: json!({"type": "object", "properties": {}}),
                    requires_network: false,
                    requires_filesystem: vec![],
                },
            ];
            Ok(RpcResponse::new(req.id, json!(tools)))
        }
        methods::TOOLS_CALL => {
            let call: McpToolCall = serde_json::from_value(req.params.unwrap_or_else(|| json!({})))
                .map_err(|e| RpcError::invalid_params(format!("Invalid tool call: {}", e)))?;

            let result = match call.name.as_str() {
                "get_prompt" => {
                    let name = call
                        .arguments
                        .get("name")
                        .and_then(|v| v.as_str())
                        .ok_or_else(|| RpcError::invalid_params("Missing 'name' parameter"))?;

                    let srv = server.lock().await;
                    match srv.get_template(name).await {
                        Some(template) => match serde_json::to_value(template) {
                            Ok(serialized) => {
                                json!({"success": true, "template": serialized})
                            }
                            Err(e) => {
                                return Err(RpcError::internal_error(format!(
                                    "Failed to serialize template '{}': {}",
                                    name, e
                                )));
                            }
                        },
                        None => json!({"success": false, "error": "Template not found"}),
                    }
                }
                "render_prompt" => {
                    let name = call
                        .arguments
                        .get("name")
                        .and_then(|v| v.as_str())
                        .ok_or_else(|| RpcError::invalid_params("Missing 'name' parameter"))?;

                    let default_vars = json!({});
                    let vars = call.arguments.get("vars").unwrap_or(&default_vars);

                    let srv = server.lock().await;
                    match srv.render_template(name, vars) {
                        Ok(rendered) => json!({"success": true, "rendered": rendered}),
                        Err(e) => json!({"success": false, "error": e.to_string()}),
                    }
                }
                "list_prompts" => {
                    let srv = server.lock().await;
                    let templates = srv.list_templates().await;
                    json!({"success": true, "templates": templates})
                }
                "reload_prompts" => {
                    let mut srv = server.lock().await;
                    match srv.reload_templates().await {
                        Ok(_) => json!({"success": true, "message": "Prompts reloaded"}),
                        Err(e) => json!({"success": false, "error": e.to_string()}),
                    }
                }
                _ => return Err(RpcError::method_not_found(&call.name)),
            };

            let resp = McpToolResponse {
                name: call.name,
                success: result
                    .get("success")
                    .and_then(|v| v.as_bool())
                    .unwrap_or(false),
                output: result,
                metadata: HashMap::new(),
                duration_ms: 0,
            };

            let payload = serde_json::to_value(resp).map_err(|e| {
                RpcError::internal_error(format!("Failed to serialize tool response: {}", e))
            })?;
            Ok(RpcResponse::new(req.id, payload))
        }
        _ => Err(RpcError::method_not_found(&req.method)),
    }
}
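Illustratively, the `params` payload deserialized into `McpToolCall` above for a `render_prompt` call would look like this (field names per the descriptors above, values invented):

{"name": "render_prompt", "arguments": {"name": "chat_mode_system", "vars": {"tools": []}}}
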
@@ -1,3 +0,0 @@
prompt: |
  Hello {{name}}!
  Your role is: {{role}}.
@@ -1,12 +0,0 @@
[package]
name = "owlen-mcp-server"
version = "0.1.0"
edition.workspace = true

[dependencies]
tokio = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
anyhow = { workspace = true }
path-clean = "1.0"
owlen-core = { path = "../../owlen-core" }
@@ -1,246 +0,0 @@
use owlen_core::mcp::protocol::{
    ErrorCode, InitializeParams, InitializeResult, PROTOCOL_VERSION, RequestId, RpcError,
    RpcErrorResponse, RpcRequest, RpcResponse, ServerCapabilities, ServerInfo, is_compatible,
};
use path_clean::PathClean;
use serde::Deserialize;
use std::env;
use std::fs;
use std::path::{Path, PathBuf};
use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt};

#[derive(Deserialize)]
struct FileArgs {
    path: String,
}

#[derive(Deserialize)]
struct WriteArgs {
    path: String,
    content: String,
}

async fn handle_request(req: &RpcRequest, root: &Path) -> Result<serde_json::Value, RpcError> {
    match req.method.as_str() {
        "initialize" => {
            let params = req
                .params
                .as_ref()
                .ok_or_else(|| RpcError::invalid_params("Missing params for initialize"))?;

            let init_params: InitializeParams =
                serde_json::from_value(params.clone()).map_err(|e| {
                    RpcError::invalid_params(format!("Invalid initialize params: {}", e))
                })?;

            // Check protocol version compatibility
            if !is_compatible(&init_params.protocol_version, PROTOCOL_VERSION) {
                return Err(RpcError::new(
                    ErrorCode::INVALID_REQUEST,
                    format!(
                        "Incompatible protocol version. Client: {}, Server: {}",
                        init_params.protocol_version, PROTOCOL_VERSION
                    ),
                ));
            }

            // Build initialization result
            let result = InitializeResult {
                protocol_version: PROTOCOL_VERSION.to_string(),
                server_info: ServerInfo {
                    name: "owlen-mcp-server".to_string(),
                    version: env!("CARGO_PKG_VERSION").to_string(),
                },
                capabilities: ServerCapabilities {
                    supports_tools: Some(false),
                    supports_resources: Some(true), // Supports read, write, delete
                    supports_streaming: Some(false),
                },
            };

            Ok(serde_json::to_value(result).map_err(|e| {
                RpcError::internal_error(format!("Failed to serialize result: {}", e))
            })?)
        }
        "resources_list" => {
            let params = req
                .params
                .as_ref()
                .ok_or_else(|| RpcError::invalid_params("Missing params"))?;
            let args: FileArgs = serde_json::from_value(params.clone())
                .map_err(|e| RpcError::invalid_params(format!("Invalid params: {}", e)))?;
            resources_list(&args.path, root).await
        }
        "resources_get" => {
            let params = req
                .params
                .as_ref()
                .ok_or_else(|| RpcError::invalid_params("Missing params"))?;
            let args: FileArgs = serde_json::from_value(params.clone())
                .map_err(|e| RpcError::invalid_params(format!("Invalid params: {}", e)))?;
            resources_get(&args.path, root).await
        }
        "resources_write" => {
            let params = req
                .params
                .as_ref()
                .ok_or_else(|| RpcError::invalid_params("Missing params"))?;
            let args: WriteArgs = serde_json::from_value(params.clone())
                .map_err(|e| RpcError::invalid_params(format!("Invalid params: {}", e)))?;
            resources_write(&args.path, &args.content, root).await
        }
        "resources_delete" => {
            let params = req
                .params
                .as_ref()
                .ok_or_else(|| RpcError::invalid_params("Missing params"))?;
            let args: FileArgs = serde_json::from_value(params.clone())
                .map_err(|e| RpcError::invalid_params(format!("Invalid params: {}", e)))?;
            resources_delete(&args.path, root).await
        }
        _ => Err(RpcError::method_not_found(&req.method)),
    }
}

fn sanitize_path(path: &str, root: &Path) -> Result<PathBuf, RpcError> {
    let path = Path::new(path);
    let path = if path.is_absolute() {
        path.strip_prefix("/")
            .map_err(|_| RpcError::invalid_params("Invalid path"))?
            .to_path_buf()
    } else {
        path.to_path_buf()
    };

    let full_path = root.join(path).clean();

    if !full_path.starts_with(root) {
        return Err(RpcError::path_traversal());
    }

    Ok(full_path)
}

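The sanitizer above re-roots absolute paths under the server root and rejects anything that cleans to a location outside it. A minimal test sketch of that behavior (root path invented):

#[cfg(test)]
mod sanitize_tests {
    use super::sanitize_path;
    use std::path::Path;

    #[test]
    fn rejects_escapes_and_reroots_absolute_paths() {
        let root = Path::new("/srv/owlen");
        // Absolute paths are re-rooted under `root` rather than honored.
        assert_eq!(
            sanitize_path("/etc/passwd", root).unwrap(),
            Path::new("/srv/owlen/etc/passwd")
        );
        // `..` components are cleaned first, then escapes are rejected.
        assert!(sanitize_path("../outside.txt", root).is_err());
    }
}
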
async fn resources_list(path: &str, root: &Path) -> Result<serde_json::Value, RpcError> {
    let full_path = sanitize_path(path, root)?;

    let entries = fs::read_dir(full_path).map_err(|e| {
        RpcError::new(
            ErrorCode::RESOURCE_NOT_FOUND,
            format!("Failed to read directory: {}", e),
        )
    })?;

    let mut result = Vec::new();
    for entry in entries {
        let entry = entry.map_err(|e| {
            RpcError::internal_error(format!("Failed to read directory entry: {}", e))
        })?;
        result.push(entry.file_name().to_string_lossy().to_string());
    }

    Ok(serde_json::json!(result))
}

async fn resources_get(path: &str, root: &Path) -> Result<serde_json::Value, RpcError> {
    let full_path = sanitize_path(path, root)?;

    let content = fs::read_to_string(full_path).map_err(|e| {
        RpcError::new(
            ErrorCode::RESOURCE_NOT_FOUND,
            format!("Failed to read file: {}", e),
        )
    })?;

    Ok(serde_json::json!(content))
}

async fn resources_write(
    path: &str,
    content: &str,
    root: &Path,
) -> Result<serde_json::Value, RpcError> {
    let full_path = sanitize_path(path, root)?;
    // Ensure parent directory exists
    if let Some(parent) = full_path.parent() {
        std::fs::create_dir_all(parent).map_err(|e| {
            RpcError::internal_error(format!("Failed to create parent directories: {}", e))
        })?;
    }
    std::fs::write(full_path, content)
        .map_err(|e| RpcError::internal_error(format!("Failed to write file: {}", e)))?;
    Ok(serde_json::json!(null))
}

async fn resources_delete(path: &str, root: &Path) -> Result<serde_json::Value, RpcError> {
    let full_path = sanitize_path(path, root)?;
    if full_path.is_file() {
        std::fs::remove_file(full_path)
            .map_err(|e| RpcError::internal_error(format!("Failed to delete file: {}", e)))?;
        Ok(serde_json::json!(null))
    } else {
        Err(RpcError::new(
            ErrorCode::RESOURCE_NOT_FOUND,
            "Path does not refer to a file",
        ))
    }
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let root = env::current_dir()?;
    let mut stdin = io::BufReader::new(io::stdin());
    let mut stdout = io::stdout();

    loop {
        let mut line = String::new();
        match stdin.read_line(&mut line).await {
            Ok(0) => {
                // EOF
                break;
            }
            Ok(_) => {
                let req: RpcRequest = match serde_json::from_str(&line) {
                    Ok(req) => req,
                    Err(e) => {
                        let err_resp = RpcErrorResponse::new(
                            RequestId::Number(0),
                            RpcError::parse_error(format!("Parse error: {}", e)),
                        );
                        let resp_str = serde_json::to_string(&err_resp)?;
                        stdout.write_all(resp_str.as_bytes()).await?;
                        stdout.write_all(b"\n").await?;
                        stdout.flush().await?;
                        continue;
                    }
                };

                let request_id = req.id.clone();

                match handle_request(&req, &root).await {
                    Ok(result) => {
                        let resp = RpcResponse::new(request_id, result);
                        let resp_str = serde_json::to_string(&resp)?;
                        stdout.write_all(resp_str.as_bytes()).await?;
                        stdout.write_all(b"\n").await?;
                        stdout.flush().await?;
                    }
                    Err(error) => {
                        let err_resp = RpcErrorResponse::new(request_id, error);
                        let resp_str = serde_json::to_string(&err_resp)?;
                        stdout.write_all(resp_str.as_bytes()).await?;
                        stdout.write_all(b"\n").await?;
                        stdout.flush().await?;
                    }
                }
            }
            Err(e) => {
                // Handle read error
                eprintln!("Error reading from stdin: {}", e);
                break;
            }
        }
    }

    Ok(())
}
@@ -1,63 +0,0 @@
[package]
name = "owlen-cli"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
repository.workspace = true
homepage.workspace = true
description = "Command-line interface for OWLEN LLM client"

[features]
default = ["chat-client"]
chat-client = ["owlen-tui"]

[[bin]]
name = "owlen"
path = "src/main.rs"
required-features = ["chat-client"]

[[bin]]
name = "owlen-code"
path = "src/code_main.rs"
required-features = ["chat-client"]

[[bin]]
name = "owlen-agent"
path = "src/agent_main.rs"
required-features = ["chat-client"]

[dependencies]
owlen-core = { path = "../owlen-core" }
owlen-providers = { path = "../owlen-providers" }
# Optional TUI dependency, enabled by the "chat-client" feature.
owlen-tui = { path = "../owlen-tui", optional = true }
log = { workspace = true }
async-trait = { workspace = true }
futures = { workspace = true }

# CLI framework
clap = { workspace = true, features = ["derive"] }

# Async runtime
tokio = { workspace = true }
tokio-util = { workspace = true }

# TUI framework
ratatui = { workspace = true }
crossterm = { workspace = true }

# Utilities
anyhow = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
regex = { workspace = true }
thiserror = { workspace = true }
dirs = { workspace = true }
base64 = { workspace = true }
mime_guess = { workspace = true }
image = { workspace = true }

[dev-dependencies]
tokio = { workspace = true }
tokio-test = { workspace = true }
@@ -1,15 +0,0 @@
# Owlen CLI

This crate is the command-line entry point for the Owlen application.

It is responsible for:

- Parsing command-line arguments.
- Loading the configuration.
- Initializing the providers.
- Starting the `owlen-tui` application.

There are three binaries:

- `owlen`: The main chat application.
- `owlen-code`: A specialized version for code-related tasks.
- `owlen-agent`: A headless runner for the ReAct agent.
@@ -1,31 +0,0 @@
use std::process::Command;

fn main() {
    const MIN_VERSION: (u32, u32, u32) = (1, 75, 0);

    let rustc = std::env::var("RUSTC").unwrap_or_else(|_| "rustc".into());
    let output = Command::new(&rustc)
        .arg("--version")
        .output()
        .expect("failed to invoke rustc");

    let version_line = String::from_utf8_lossy(&output.stdout);
    let version_str = version_line.split_whitespace().nth(1).unwrap_or("0.0.0");
    let sanitized = version_str.split('-').next().unwrap_or(version_str);

    let mut parts = sanitized
        .split('.')
        .map(|part| part.parse::<u32>().unwrap_or(0));
    let current = (
        parts.next().unwrap_or(0),
        parts.next().unwrap_or(0),
        parts.next().unwrap_or(0),
    );

    if current < MIN_VERSION {
        panic!(
            "owlen requires rustc {}.{}.{} or newer (found {version_line})",
            MIN_VERSION.0, MIN_VERSION.1, MIN_VERSION.2
        );
    }
}
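Worked example of the parse above: `rustc 1.82.0 (f6e511eec 2024-10-15)` yields `version_str = "1.82.0"` and `current = (1, 82, 0)`, which passes the `(1, 75, 0)` minimum; a nightly string like `rustc 1.84.0-nightly (...)` is reduced to `"1.84.0"` by the `split('-')` step before the comparison.
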
@@ -1,285 +0,0 @@
//! Simple entry point for the ReAct agentic executor.
//!
//! Usage: `owlen-agent "<prompt>" [--model <model>] [--max-iter <n>]`
//!
//! This binary demonstrates Phase 4 without the full TUI. It creates an
//! OllamaProvider, a RemoteMcpClient, runs the AgentExecutor and prints the
//! final answer.

use std::{
    path::{Path, PathBuf},
    sync::Arc,
};

use anyhow::Context;
use base64::{Engine, engine::general_purpose::STANDARD as BASE64_STANDARD};
use clap::{Parser, builder::ValueHint};
use image::{self, GenericImageView, imageops::FilterType};
use owlen_cli::agent::{AgentConfig, AgentExecutor};
use owlen_core::{mcp::remote_client::RemoteMcpClient, types::MessageAttachment};
use tokio::fs;

const MAX_ATTACHMENT_BYTES: u64 = 8 * 1024 * 1024;
const ATTACHMENT_ASCII_WIDTH: u32 = 24;
const ATTACHMENT_ASCII_HEIGHT: u32 = 12;
const ATTACHMENT_TEXT_PREVIEW_LINES: usize = 12;
const ATTACHMENT_TEXT_PREVIEW_WIDTH: usize = 80;
const ATTACHMENT_INLINE_PREVIEW_LINES: usize = 6;

/// Command-line arguments for the agent binary.
#[derive(Parser, Debug)]
#[command(
    name = "owlen-agent",
    author,
    version,
    about = "Run the ReAct agent via MCP"
)]
struct Args {
    /// The initial user query.
    prompt: String,
    /// Paths to files that should be sent with the initial turn.
    #[arg(long = "attach", short = 'a', value_name = "PATH", value_hint = ValueHint::FilePath)]
    attachments: Vec<PathBuf>,
    /// Model to use (defaults to Ollama default).
    #[arg(long)]
    model: Option<String>,
    /// Maximum ReAct iterations.
    #[arg(long, default_value_t = 10)]
    max_iter: usize,
}

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let Args {
        prompt,
        attachments: attachment_paths,
        model,
        max_iter,
    } = args;

    let attachments = load_attachments(&attachment_paths).await?;
    if !attachments.is_empty() {
        println!(
            "Attaching {} {}:",
            attachments.len(),
            if attachments.len() == 1 {
                "artifact"
            } else {
                "artifacts"
            }
        );
        render_attachment_previews(&attachments);
    }

    // Initialise the MCP LLM client – it implements Provider and talks to the
    // MCP LLM server which wraps Ollama. This ensures all communication goes
    // through the MCP architecture (Phase 10 requirement).
    let provider = Arc::new(RemoteMcpClient::new().await?);

    // The MCP client also serves as the tool client for resource operations
    let mcp_client = Arc::clone(&provider) as Arc<RemoteMcpClient>;

    let config = AgentConfig {
        max_iterations: max_iter,
        model: model.unwrap_or_else(|| "llama3.2:latest".to_string()),
        ..AgentConfig::default()
    };

    let executor = AgentExecutor::new(provider, mcp_client, config);
    match executor.run_with_attachments(prompt, attachments).await {
        Ok(result) => {
            println!("\n✓ Agent completed in {} iterations", result.iterations);
            println!("\nFinal answer:\n{}", result.answer);
            Ok(())
        }
        Err(e) => Err(anyhow::anyhow!(e)),
    }
}

async fn load_attachments(paths: &[PathBuf]) -> anyhow::Result<Vec<MessageAttachment>> {
    let mut attachments = Vec::new();
    for path in paths {
        let attachment = load_attachment(path).await?;
        attachments.push(attachment);
    }
    Ok(attachments)
}

async fn load_attachment(path: &Path) -> anyhow::Result<MessageAttachment> {
    let metadata = fs::metadata(path)
        .await
        .with_context(|| format!("Unable to inspect {}", path.display()))?;
    if !metadata.is_file() {
        return Err(anyhow::anyhow!("{} is not a regular file", path.display()));
    }
    if metadata.len() > MAX_ATTACHMENT_BYTES {
        return Err(anyhow::anyhow!(
            "Attachments are limited to {} (requested {}): {}",
            format_attachment_size(MAX_ATTACHMENT_BYTES),
            format_attachment_size(metadata.len()),
            path.display()
        ));
    }

    let bytes = fs::read(path)
        .await
        .with_context(|| format!("Failed to read {}", path.display()))?;
    let mime = mime_guess::from_path(path).first_or_octet_stream();
    let mime_string = mime.essence_str().to_string();
    let file_name = path
        .file_name()
        .and_then(|value| value.to_str())
        .unwrap_or("attachment")
        .to_string();

    let is_text = mime_string.starts_with("text/") || std::str::from_utf8(&bytes).is_ok();
    let mut preview_lines = Vec::new();

    let mut attachment = if is_text {
        let text = String::from_utf8_lossy(&bytes).into_owned();
        preview_lines = preview_lines_for_text(&text);
        let mut attachment =
            MessageAttachment::from_text(Some(file_name.clone()), mime_string.clone(), text);
        attachment.size_bytes = Some(metadata.len());
        attachment
    } else {
        if mime_string.starts_with("image/")
            && let Some(lines) = preview_lines_for_image(&bytes)
        {
            preview_lines = lines;
        }
        let encoded = BASE64_STANDARD.encode(&bytes);
        MessageAttachment::from_base64(
            file_name.clone(),
            mime_string.clone(),
            encoded,
            Some(metadata.len()),
        )
    };

    attachment.size_bytes = Some(metadata.len());
    attachment = attachment.with_source_path(path.to_path_buf());
    if !preview_lines.is_empty() {
        attachment = attachment.with_preview_lines(preview_lines);
    }

    Ok(attachment)
}

fn render_attachment_previews(attachments: &[MessageAttachment]) {
    for (idx, attachment) in attachments.iter().enumerate() {
        println!(" {}. {}", idx + 1, summarize_attachment(attachment));
        if let Some(lines) = attachment.preview_lines.as_ref() {
            for line in lines.iter().take(ATTACHMENT_INLINE_PREVIEW_LINES) {
                println!("    {}", line);
            }
            if lines.len() > ATTACHMENT_INLINE_PREVIEW_LINES {
                println!("    …");
            }
        }
    }
    if !attachments.is_empty() {
        println!();
    }
}

fn summarize_attachment(attachment: &MessageAttachment) -> String {
    let icon = if attachment.is_image() {
        "📷"
    } else if attachment
        .mime_type
        .to_ascii_lowercase()
        .starts_with("text/")
    {
        "📄"
    } else {
        "📎"
    };
    let name = attachment
        .name
        .as_deref()
        .unwrap_or(attachment.mime_type.as_str());
    let mut parts = vec![format!("{icon} {name}"), attachment.mime_type.clone()];
    if let Some(size) = attachment.size_bytes {
        parts.push(format_attachment_size(size));
    }
    parts.join(" · ")
}

fn format_attachment_size(bytes: u64) -> String {
    const UNITS: &[&str] = &["B", "KB", "MB", "GB", "TB"];
    let mut value = bytes as f64;
    let mut index = 0usize;
    while value >= 1024.0 && index < UNITS.len() - 1 {
        value /= 1024.0;
        index += 1;
    }
    if index == 0 {
        format!("{bytes} {}", UNITS[index])
    } else {
        format!("{value:.1} {}", UNITS[index])
    }
}

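Worked examples of the size helper above: 512 → "512 B" (index 0 keeps the integer form), 2048 → "2.0 KB", and the 8 * 1024 * 1024 attachment limit renders as "8.0 MB" in the error message of load_attachment.
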
fn preview_lines_for_text(text: &str) -> Vec<String> {
    let mut lines = Vec::new();
    for raw in text.lines().take(ATTACHMENT_TEXT_PREVIEW_LINES) {
        let trimmed = raw.trim_end();
        if trimmed.is_empty() {
            lines.push(String::new());
            continue;
        }
        let mut snippet = trimmed
            .chars()
            .take(ATTACHMENT_TEXT_PREVIEW_WIDTH)
            .collect::<String>();
        if trimmed.chars().count() > ATTACHMENT_TEXT_PREVIEW_WIDTH {
            snippet.push('…');
        }
        lines.push(snippet);
    }

    if lines.is_empty() {
        lines.push("(empty attachment)".to_string());
    }

    lines
}

fn preview_lines_for_image(bytes: &[u8]) -> Option<Vec<String>> {
    let image = image::load_from_memory(bytes).ok()?;
    let (width, height) = image.dimensions();
    let mut lines = Vec::new();
    lines.push(format!("{width} × {height} px"));

    let target_width = ATTACHMENT_ASCII_WIDTH;
    let target_height = ATTACHMENT_ASCII_HEIGHT;
    let scale = (target_width as f32 / width as f32)
        .min(target_height as f32 / height as f32)
        .clamp(0.05, 1.0);
    let scaled_width = (width as f32 * scale).max(1.0).round() as u32;
    let scaled_height = (height as f32 * scale).max(1.0).round() as u32;
    let resized = image
        .resize_exact(
            scaled_width.max(1),
            scaled_height.max(1),
            FilterType::Triangle,
        )
        .to_luma8();

    const PALETTE: [char; 10] = [' ', '.', ':', '-', '=', '+', '*', '#', '%', '@'];
    for y in 0..resized.height() {
        let mut row = String::with_capacity((resized.width() as usize) * 2);
        for x in 0..resized.width() {
            let luminance = resized.get_pixel(x, y)[0] as usize;
            let idx = luminance * (PALETTE.len() - 1) / 255;
            let ch = PALETTE[idx];
            row.push(ch);
            row.push(ch);
        }
        lines.push(row);
    }

    Some(lines)
}
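The luminance-to-glyph mapping above is linear: `idx = luminance * (PALETTE.len() - 1) / 255`. A black pixel (0) maps to ' ', mid-gray (128) to index 128 * 9 / 255 = 4, i.e. '=', and white (255) to '@'. Each glyph is pushed twice per pixel, presumably to compensate for terminal cells being roughly twice as tall as they are wide.
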
@@ -1,340 +0,0 @@
|
||||
use std::borrow::Cow;
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use async_trait::async_trait;
|
||||
use crossterm::{
|
||||
event::{DisableBracketedPaste, DisableMouseCapture, EnableBracketedPaste, EnableMouseCapture},
|
||||
execute,
|
||||
terminal::{EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode},
|
||||
};
|
||||
use futures::stream;
|
||||
use owlen_core::{
|
||||
ChatStream, Error, Provider,
|
||||
config::{Config, McpMode},
|
||||
mcp::remote_client::RemoteMcpClient,
|
||||
mode::Mode,
|
||||
provider::ProviderManager,
|
||||
providers::OllamaProvider,
|
||||
session::{ControllerEvent, SessionController},
|
||||
storage::StorageManager,
|
||||
types::{ChatRequest, ChatResponse, Message, ModelInfo},
|
||||
};
|
||||
use owlen_tui::{
|
||||
ChatApp, SessionEvent,
|
||||
app::App as RuntimeApp,
|
||||
config,
|
||||
tui_controller::{TuiController, TuiRequest},
|
||||
ui,
|
||||
};
|
||||
use ratatui::{Terminal, prelude::CrosstermBackend};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::commands::cloud::{load_runtime_credentials, set_env_var};
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct LaunchOptions {
|
||||
pub disable_auto_compress: bool,
|
||||
}
|
||||
|
||||
pub async fn launch(initial_mode: Mode, options: LaunchOptions) -> Result<()> {
|
||||
set_env_var("OWLEN_AUTO_CONSENT", "1");
|
||||
|
||||
let color_support = detect_terminal_color_support();
|
||||
let mut cfg = config::try_load_config().unwrap_or_default();
|
||||
let _ = cfg.refresh_mcp_servers(None);
|
||||
|
||||
if options.disable_auto_compress {
|
||||
cfg.chat.auto_compress = false;
|
||||
}
|
||||
|
||||
if let Some(previous_theme) = apply_terminal_theme(&mut cfg, &color_support) {
|
||||
let term_label = match &color_support {
|
||||
TerminalColorSupport::Limited { term } => Cow::from(term.as_str()),
|
||||
TerminalColorSupport::Full => Cow::from("current terminal"),
|
||||
};
|
||||
eprintln!(
|
||||
"Terminal '{}' lacks full 256-color support. Using '{}' theme instead of '{}'.",
|
||||
term_label, BASIC_THEME_NAME, previous_theme
|
||||
);
|
||||
} else if let TerminalColorSupport::Limited { term } = &color_support {
|
||||
eprintln!(
|
||||
"Warning: terminal '{}' may not fully support 256-color themes.",
|
||||
term
|
||||
);
|
||||
}
|
||||
|
||||
cfg.validate()?;
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
load_runtime_credentials(&mut cfg, storage.clone()).await?;
|
||||
|
||||
let (tui_tx, _tui_rx) = mpsc::unbounded_channel::<TuiRequest>();
|
||||
let tui_controller = Arc::new(TuiController::new(tui_tx));
|
||||
|
||||
let provider = build_provider(&cfg).await?;
|
||||
let mut offline_notice: Option<String> = None;
|
||||
let provider = match provider.health_check().await {
|
||||
Ok(_) => provider,
|
||||
Err(err) => {
|
||||
let hint = if matches!(cfg.mcp.mode, McpMode::RemotePreferred | McpMode::RemoteOnly)
|
||||
&& !cfg.effective_mcp_servers().is_empty()
|
||||
{
|
||||
"Ensure the configured MCP server is running and reachable."
|
||||
} else {
|
||||
"Ensure Ollama is running (`ollama serve`) and reachable at the configured base_url."
|
||||
};
|
||||
let notice =
|
||||
format!("Provider health check failed: {err}. {hint} Continuing in offline mode.");
|
||||
eprintln!("{notice}");
|
||||
offline_notice = Some(notice.clone());
|
||||
let fallback_model = cfg
|
||||
.general
|
||||
.default_model
|
||||
.clone()
|
||||
.unwrap_or_else(|| "offline".to_string());
|
||||
Arc::new(OfflineProvider::new(notice, fallback_model)) as Arc<dyn Provider>
|
||||
}
|
||||
};
|
||||
|
||||
let (controller_event_tx, controller_event_rx) = mpsc::unbounded_channel::<ControllerEvent>();
|
||||
let controller = SessionController::new(
|
||||
provider,
|
||||
cfg,
|
||||
storage.clone(),
|
||||
tui_controller,
|
||||
false,
|
||||
Some(controller_event_tx),
|
||||
)
|
||||
.await?;
|
||||
let provider_manager = Arc::new(ProviderManager::default());
|
||||
let mut runtime = RuntimeApp::new(provider_manager);
|
||||
let (mut app, mut session_rx) = ChatApp::new(controller, controller_event_rx).await?;
|
||||
app.initialize_models().await?;
|
||||
if let Some(notice) = offline_notice.clone() {
|
||||
app.set_status_message(¬ice);
|
||||
app.set_system_status(notice);
|
||||
}
|
||||
|
||||
if options.disable_auto_compress {
|
||||
app.append_system_status("Auto compression off");
|
||||
}
|
||||
|
||||
app.set_mode(initial_mode).await;
|
||||
|
||||
enable_raw_mode()?;
|
||||
let mut stdout = io::stdout();
|
||||
execute!(
|
||||
stdout,
|
||||
EnterAlternateScreen,
|
||||
EnableMouseCapture,
|
||||
EnableBracketedPaste
|
||||
)?;
|
||||
let backend = CrosstermBackend::new(stdout);
|
||||
let mut terminal = Terminal::new(backend)?;
|
||||
|
||||
let result = run_app(&mut terminal, &mut runtime, &mut app, &mut session_rx).await;
|
||||
|
||||
config::save_config(&app.config())?;
|
||||
|
||||
disable_raw_mode()?;
|
||||
execute!(
|
||||
terminal.backend_mut(),
|
||||
LeaveAlternateScreen,
|
||||
DisableMouseCapture,
|
||||
DisableBracketedPaste
|
||||
)?;
|
||||
terminal.show_cursor()?;
|
||||
|
||||
if let Err(err) = result {
|
||||
println!("{err:?}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn build_provider(cfg: &Config) -> Result<Arc<dyn Provider>> {
|
||||
match cfg.mcp.mode {
|
||||
McpMode::RemotePreferred => {
|
||||
let remote_result = if let Some(mcp_server) = cfg.effective_mcp_servers().first() {
|
||||
RemoteMcpClient::new_with_config(mcp_server).await
|
||||
} else {
|
||||
RemoteMcpClient::new().await
|
||||
};
|
||||
|
||||
match remote_result {
|
||||
Ok(client) => Ok(Arc::new(client) as Arc<dyn Provider>),
|
||||
Err(err) if cfg.mcp.allow_fallback => {
|
||||
log::warn!(
|
||||
"Remote MCP client unavailable ({}); falling back to local provider.",
|
||||
err
|
||||
);
|
||||
build_local_provider(cfg)
|
||||
}
|
||||
Err(err) => Err(anyhow!(err)),
|
||||
}
|
||||
}
|
||||
McpMode::RemoteOnly => {
|
||||
let mcp_server = cfg.effective_mcp_servers().first().ok_or_else(|| {
|
||||
anyhow!("[[mcp_servers]] must be configured when [mcp].mode = \"remote_only\"")
|
||||
})?;
|
||||
let client = RemoteMcpClient::new_with_config(mcp_server).await?;
|
||||
Ok(Arc::new(client) as Arc<dyn Provider>)
|
||||
}
|
||||
McpMode::LocalOnly | McpMode::Legacy => build_local_provider(cfg),
|
||||
McpMode::Disabled => Err(anyhow!(
|
||||
"MCP mode 'disabled' is not supported by the owlen TUI"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
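/// Builds the provider named by `general.default_provider`; only the Ollama
/// provider types are supported in legacy/local MCP mode.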
fn build_local_provider(cfg: &Config) -> Result<Arc<dyn Provider>> {
|
||||
let provider_name = cfg.general.default_provider.clone();
|
||||
let provider_cfg = cfg.provider(&provider_name).ok_or_else(|| {
|
||||
anyhow!("No provider configuration found for '{provider_name}' in [providers]")
|
||||
})?;
|
||||
|
||||
match provider_cfg.provider_type.as_str() {
|
||||
"ollama" | "ollama_cloud" => {
|
||||
let provider =
|
||||
OllamaProvider::from_config(&provider_name, provider_cfg, Some(&cfg.general))?;
|
||||
Ok(Arc::new(provider) as Arc<dyn Provider>)
|
||||
}
|
||||
other => Err(anyhow!(
    "Provider type '{other}' is not supported in legacy/local MCP mode"
)),
|
||||
}
|
||||
}
|
||||
|
||||
const BASIC_THEME_NAME: &str = "ansi_basic";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum TerminalColorSupport {
|
||||
Full,
|
||||
Limited { term: String },
|
||||
}
|
||||
|
||||
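/// Inspects `TERM` and `COLORTERM` to decide whether the terminal advertises
/// extended colour support (256-colour or truecolor) or should be treated as
/// a basic terminal.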
fn detect_terminal_color_support() -> TerminalColorSupport {
|
||||
let term = std::env::var("TERM").unwrap_or_else(|_| "unknown".to_string());
|
||||
let colorterm = std::env::var("COLORTERM").unwrap_or_default();
|
||||
let term_lower = term.to_lowercase();
|
||||
let color_lower = colorterm.to_lowercase();
|
||||
|
||||
let supports_extended = term_lower.contains("256color")
|
||||
|| color_lower.contains("truecolor")
|
||||
|| color_lower.contains("24bit")
|
||||
|| color_lower.contains("fullcolor");
|
||||
|
||||
if supports_extended {
|
||||
TerminalColorSupport::Full
|
||||
} else {
|
||||
TerminalColorSupport::Limited { term }
|
||||
}
|
||||
}
|
||||
|
||||
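/// Downgrades the configured theme to the basic ANSI theme on limited
/// terminals, returning the previous theme name so callers can warn the user.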
fn apply_terminal_theme(cfg: &mut Config, support: &TerminalColorSupport) -> Option<String> {
|
||||
match support {
|
||||
TerminalColorSupport::Full => None,
|
||||
TerminalColorSupport::Limited { .. } => {
|
||||
if cfg.ui.theme != BASIC_THEME_NAME {
|
||||
let previous = std::mem::replace(&mut cfg.ui.theme, BASIC_THEME_NAME.to_string());
|
||||
Some(previous)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
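/// Stub provider installed when the startup health check fails, letting the
/// TUI launch and explain the outage instead of exiting immediately.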
struct OfflineProvider {
|
||||
reason: String,
|
||||
placeholder_model: String,
|
||||
}
|
||||
|
||||
impl OfflineProvider {
|
||||
fn new(reason: String, placeholder_model: String) -> Self {
|
||||
Self {
|
||||
reason,
|
||||
placeholder_model,
|
||||
}
|
||||
}
|
||||
|
||||
fn friendly_response(&self, requested_model: &str) -> ChatResponse {
|
||||
let mut message = String::new();
|
||||
message.push_str("⚠️ Owlen is running in offline mode.\n\n");
|
||||
message.push_str(&self.reason);
|
||||
if !requested_model.is_empty() && requested_model != self.placeholder_model {
|
||||
message.push_str(&format!(
|
||||
"\n\nYou requested model '{}', but no providers are reachable.",
|
||||
requested_model
|
||||
));
|
||||
}
|
||||
message.push_str(
|
||||
"\n\nStart your preferred provider (e.g. `ollama serve`) or switch providers with `:provider` once connectivity is restored.",
|
||||
);
|
||||
|
||||
ChatResponse {
|
||||
message: Message::assistant(message),
|
||||
usage: None,
|
||||
is_streaming: false,
|
||||
is_final: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for OfflineProvider {
|
||||
fn name(&self) -> &str {
|
||||
"offline"
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>, Error> {
|
||||
Ok(vec![ModelInfo {
|
||||
id: self.placeholder_model.clone(),
|
||||
provider: "offline".to_string(),
|
||||
name: format!("Offline (fallback: {})", self.placeholder_model),
|
||||
description: Some("Placeholder model used while no providers are reachable".into()),
|
||||
context_window: None,
|
||||
capabilities: vec![],
|
||||
supports_tools: false,
|
||||
}])
|
||||
}
|
||||
|
||||
async fn send_prompt(&self, request: ChatRequest) -> Result<ChatResponse, Error> {
|
||||
Ok(self.friendly_response(&request.model))
|
||||
}
|
||||
|
||||
async fn stream_prompt(&self, request: ChatRequest) -> Result<ChatStream, Error> {
|
||||
let response = self.friendly_response(&request.model);
|
||||
Ok(Box::pin(stream::iter(vec![Ok(response)])))
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> Result<(), Error> {
|
||||
Err(Error::Provider(anyhow!(
|
||||
"offline provider cannot reach any backing models"
|
||||
)))
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_app(
|
||||
terminal: &mut Terminal<CrosstermBackend<io::Stdout>>,
|
||||
runtime: &mut RuntimeApp,
|
||||
app: &mut ChatApp,
|
||||
session_rx: &mut mpsc::UnboundedReceiver<SessionEvent>,
|
||||
) -> Result<()> {
|
||||
let mut render = |terminal: &mut Terminal<CrosstermBackend<io::Stdout>>,
|
||||
state: &mut ChatApp|
|
||||
-> Result<()> {
|
||||
terminal.draw(|f| ui::render_chat(f, state))?;
|
||||
Ok(())
|
||||
};
|
||||
|
||||
runtime.run(terminal, app, session_rx, &mut render).await?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
//! Owlen CLI entrypoint optimised for code-first workflows.
|
||||
#![allow(dead_code, unused_imports)]
|
||||
|
||||
mod bootstrap;
|
||||
mod commands;
|
||||
mod mcp;
|
||||
|
||||
use anyhow::Result;
|
||||
use owlen_core::config as core_config;
|
||||
use owlen_core::mode::Mode;
|
||||
use owlen_tui::config;
|
||||
|
||||
#[tokio::main(flavor = "multi_thread")]
|
||||
async fn main() -> Result<()> {
|
||||
bootstrap::launch(Mode::Code, bootstrap::LaunchOptions::default()).await
|
||||
}
|
||||
@@ -1,470 +0,0 @@
|
||||
use std::ffi::OsStr;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context, Result, anyhow, bail};
|
||||
use clap::Subcommand;
|
||||
use owlen_core::LlmProvider;
|
||||
use owlen_core::ProviderConfig;
|
||||
use owlen_core::config::{
|
||||
self as core_config, Config, LEGACY_OLLAMA_CLOUD_API_KEY_ENV,
|
||||
LEGACY_OWLEN_OLLAMA_CLOUD_API_KEY_ENV, OLLAMA_API_KEY_ENV, OLLAMA_CLOUD_BASE_URL,
|
||||
OLLAMA_CLOUD_ENDPOINT_KEY, OLLAMA_MODE_KEY,
|
||||
};
|
||||
use owlen_core::credentials::{ApiCredentials, CredentialManager, OLLAMA_CLOUD_CREDENTIAL_ID};
|
||||
use owlen_core::encryption;
|
||||
use owlen_core::providers::OllamaProvider;
|
||||
use owlen_core::storage::StorageManager;
|
||||
use serde_json::Value;
|
||||
|
||||
const DEFAULT_CLOUD_ENDPOINT: &str = OLLAMA_CLOUD_BASE_URL;
|
||||
const CLOUD_ENDPOINT_KEY: &str = OLLAMA_CLOUD_ENDPOINT_KEY;
|
||||
const CLOUD_PROVIDER_KEY: &str = "ollama_cloud";
|
||||
|
||||
#[derive(Debug, Subcommand)]
|
||||
pub enum CloudCommand {
|
||||
/// Configure Ollama Cloud credentials
|
||||
Setup {
|
||||
/// API key passed directly on the command line
|
||||
#[arg(long)]
|
||||
api_key: Option<String>,
|
||||
/// Override the cloud endpoint (default: https://ollama.com)
|
||||
#[arg(long)]
|
||||
endpoint: Option<String>,
|
||||
/// Provider name to configure (default: ollama_cloud)
|
||||
#[arg(long, default_value = "ollama_cloud")]
|
||||
provider: String,
|
||||
/// Overwrite the provider base URL with the cloud endpoint
|
||||
#[arg(long)]
|
||||
force_cloud_base_url: bool,
|
||||
},
|
||||
/// Check connectivity to Ollama Cloud
|
||||
Status {
|
||||
/// Provider name to check (default: ollama_cloud)
|
||||
#[arg(long, default_value = "ollama_cloud")]
|
||||
provider: String,
|
||||
},
|
||||
/// List available cloud-hosted models
|
||||
Models {
|
||||
/// Provider name to query (default: ollama_cloud)
|
||||
#[arg(long, default_value = "ollama_cloud")]
|
||||
provider: String,
|
||||
},
|
||||
/// Remove stored Ollama Cloud credentials
|
||||
Logout {
|
||||
/// Provider name to clear (default: ollama_cloud)
|
||||
#[arg(long, default_value = "ollama_cloud")]
|
||||
provider: String,
|
||||
},
|
||||
}
|
||||
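// Illustrative invocations (subcommands and flags as declared above; the
// `owlen` binary name follows the `owlen cloud setup` hint used elsewhere in
// this file):
//
//   owlen cloud setup --api-key <KEY> --endpoint https://ollama.com
//   owlen cloud status --provider ollama_cloud
//   owlen cloud models
//   owlen cloud logout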
|
||||
pub async fn run_cloud_command(command: CloudCommand) -> Result<()> {
|
||||
match command {
|
||||
CloudCommand::Setup {
|
||||
api_key,
|
||||
endpoint,
|
||||
provider,
|
||||
force_cloud_base_url,
|
||||
} => setup(provider, api_key, endpoint, force_cloud_base_url).await,
|
||||
CloudCommand::Status { provider } => status(provider).await,
|
||||
CloudCommand::Models { provider } => models(provider).await,
|
||||
CloudCommand::Logout { provider } => logout(provider).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn setup(
|
||||
provider: String,
|
||||
api_key: Option<String>,
|
||||
endpoint: Option<String>,
|
||||
force_cloud_base_url: bool,
|
||||
) -> Result<()> {
|
||||
let provider = canonical_provider_name(&provider);
|
||||
let mut config = crate::config::try_load_config().unwrap_or_default();
|
||||
let endpoint =
|
||||
normalize_endpoint(&endpoint.unwrap_or_else(|| DEFAULT_CLOUD_ENDPOINT.to_string()));
|
||||
|
||||
let base_changed = {
|
||||
let entry = ensure_provider_entry(&mut config, &provider);
|
||||
entry.enabled = true;
|
||||
configure_cloud_endpoint(entry, &endpoint, force_cloud_base_url)
|
||||
};
|
||||
|
||||
let mut credential_manager: Option<Arc<CredentialManager>> = None;
|
||||
if config.privacy.encrypt_local_data {
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
credential_manager = Some(unlock_credential_manager(&config, storage)?);
|
||||
}
|
||||
|
||||
let mut key_opt = api_key.filter(|value| !value.trim().is_empty());
|
||||
|
||||
if key_opt.is_none() {
|
||||
if let Some(manager) = credential_manager.as_ref() {
|
||||
if let Some(credentials) = manager.get_credentials(OLLAMA_CLOUD_CREDENTIAL_ID).await? {
|
||||
key_opt = Some(credentials.api_key);
|
||||
}
|
||||
} else if let Some(existing) = config
|
||||
.provider(&provider)
|
||||
.and_then(|cfg| cfg.api_key.clone())
|
||||
{
|
||||
key_opt = Some(existing);
|
||||
}
|
||||
}
|
||||
|
||||
let key = key_opt
|
||||
.map(|value| value.trim().to_string())
|
||||
.filter(|value| !value.is_empty())
|
||||
.ok_or_else(|| {
|
||||
anyhow!(
|
||||
"API key is required when configuring provider `{provider}`. \
|
||||
Supply the --api-key flag, set providers.{provider}.api_key in config.toml, \
|
||||
or populate the credential vault."
|
||||
)
|
||||
})?;
|
||||
|
||||
if let Some(manager) = credential_manager.clone() {
|
||||
let credentials = ApiCredentials {
|
||||
api_key: key.clone(),
|
||||
endpoint: endpoint.clone(),
|
||||
};
|
||||
manager
|
||||
.store_credentials(OLLAMA_CLOUD_CREDENTIAL_ID, &credentials)
|
||||
.await?;
|
||||
// Ensure plaintext key is not persisted to disk.
|
||||
if let Some(entry) = config.providers.get_mut(&provider) {
|
||||
entry.api_key = None;
|
||||
}
|
||||
} else if let Some(entry) = config.providers.get_mut(&provider) {
|
||||
entry.api_key = Some(key.clone());
|
||||
}
|
||||
|
||||
crate::config::save_config(&config)?;
|
||||
println!("Saved Ollama configuration for provider '{provider}'.");
|
||||
if config.privacy.encrypt_local_data {
|
||||
println!("API key stored securely in the encrypted credential vault.");
|
||||
} else {
|
||||
println!("API key stored in plaintext configuration (encryption disabled).");
|
||||
}
|
||||
if !force_cloud_base_url && !base_changed {
|
||||
println!(
|
||||
"Local base URL preserved; cloud endpoint stored as {}.",
|
||||
CLOUD_ENDPOINT_KEY
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn status(provider: String) -> Result<()> {
|
||||
let provider = canonical_provider_name(&provider);
|
||||
let mut config = crate::config::try_load_config().unwrap_or_default();
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
let manager = if config.privacy.encrypt_local_data {
|
||||
Some(unlock_credential_manager(&config, storage.clone())?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let api_key = hydrate_api_key(&mut config, manager.as_ref()).await?;
|
||||
{
|
||||
let entry = ensure_provider_entry(&mut config, &provider);
|
||||
entry.enabled = true;
|
||||
configure_cloud_endpoint(entry, DEFAULT_CLOUD_ENDPOINT, false);
|
||||
}
|
||||
|
||||
let provider_cfg = config
|
||||
.provider(&provider)
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("Provider '{provider}' is not configured"))?;
|
||||
|
||||
let endpoint =
|
||||
resolve_cloud_endpoint(&provider_cfg).unwrap_or_else(|| DEFAULT_CLOUD_ENDPOINT.to_string());
|
||||
let mut runtime_cfg = provider_cfg.clone();
|
||||
runtime_cfg.base_url = Some(endpoint.clone());
|
||||
runtime_cfg.extra.insert(
|
||||
OLLAMA_MODE_KEY.to_string(),
|
||||
Value::String("cloud".to_string()),
|
||||
);
|
||||
|
||||
let ollama = OllamaProvider::from_config(&provider, &runtime_cfg, Some(&config.general))
|
||||
.with_context(|| "Failed to construct Ollama provider. Run `owlen cloud setup` first.")?;
|
||||
|
||||
match ollama.health_check().await {
|
||||
Ok(_) => {
|
||||
println!("✓ Connected to {provider} ({})", endpoint);
|
||||
if api_key.is_none() && config.privacy.encrypt_local_data {
|
||||
println!(
|
||||
"Warning: No API key stored; connection succeeded via environment variables."
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
println!("✗ Failed to reach {provider}: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn models(provider: String) -> Result<()> {
|
||||
let provider = canonical_provider_name(&provider);
|
||||
let mut config = crate::config::try_load_config().unwrap_or_default();
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
let manager = if config.privacy.encrypt_local_data {
|
||||
Some(unlock_credential_manager(&config, storage.clone())?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
hydrate_api_key(&mut config, manager.as_ref()).await?;
|
||||
|
||||
{
|
||||
let entry = ensure_provider_entry(&mut config, &provider);
|
||||
entry.enabled = true;
|
||||
configure_cloud_endpoint(entry, DEFAULT_CLOUD_ENDPOINT, false);
|
||||
}
|
||||
|
||||
let provider_cfg = config
|
||||
.provider(&provider)
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("Provider '{provider}' is not configured"))?;
|
||||
|
||||
let endpoint =
|
||||
resolve_cloud_endpoint(&provider_cfg).unwrap_or_else(|| DEFAULT_CLOUD_ENDPOINT.to_string());
|
||||
let mut runtime_cfg = provider_cfg.clone();
|
||||
runtime_cfg.base_url = Some(endpoint);
|
||||
runtime_cfg.extra.insert(
|
||||
OLLAMA_MODE_KEY.to_string(),
|
||||
Value::String("cloud".to_string()),
|
||||
);
|
||||
|
||||
let ollama = OllamaProvider::from_config(&provider, &runtime_cfg, Some(&config.general))
|
||||
.with_context(|| "Failed to construct Ollama provider. Run `owlen cloud setup` first.")?;
|
||||
|
||||
match ollama.list_models().await {
|
||||
Ok(models) => {
|
||||
if models.is_empty() {
|
||||
println!("No cloud models reported by '{}'.", provider);
|
||||
} else {
|
||||
println!("Models available via '{}':", provider);
|
||||
for model in models {
|
||||
if let Some(description) = &model.description {
|
||||
println!(" - {} ({})", model.id, description);
|
||||
} else {
|
||||
println!(" - {}", model.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
bail!("Failed to list models: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn logout(provider: String) -> Result<()> {
|
||||
let provider = canonical_provider_name(&provider);
|
||||
let mut config = crate::config::try_load_config().unwrap_or_default();
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
|
||||
if config.privacy.encrypt_local_data {
|
||||
let manager = unlock_credential_manager(&config, storage.clone())?;
|
||||
manager
|
||||
.delete_credentials(OLLAMA_CLOUD_CREDENTIAL_ID)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if let Some(entry) = config.providers.get_mut(&provider) {
|
||||
entry.api_key = None;
|
||||
entry.enabled = false;
|
||||
}
|
||||
|
||||
crate::config::save_config(&config)?;
|
||||
println!("Cleared credentials for provider '{provider}'.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ensure_provider_entry<'a>(config: &'a mut Config, provider: &str) -> &'a mut ProviderConfig {
|
||||
core_config::ensure_provider_config_mut(config, provider)
|
||||
}
|
||||
|
||||
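/// Stores the normalised cloud endpoint in the provider's extra settings,
/// upgrades legacy API-key environment variable names, and reports whether
/// the base URL changed.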
fn configure_cloud_endpoint(entry: &mut ProviderConfig, endpoint: &str, force: bool) -> bool {
|
||||
let normalized = normalize_endpoint(endpoint);
|
||||
let previous_base = entry.base_url.clone();
|
||||
entry.extra.insert(
|
||||
CLOUD_ENDPOINT_KEY.to_string(),
|
||||
Value::String(normalized.clone()),
|
||||
);
|
||||
|
||||
let should_update_env = match entry.api_key_env.as_deref() {
|
||||
None => true,
|
||||
Some(value) => {
|
||||
value.eq_ignore_ascii_case(LEGACY_OLLAMA_CLOUD_API_KEY_ENV)
|
||||
|| value.eq_ignore_ascii_case(LEGACY_OWLEN_OLLAMA_CLOUD_API_KEY_ENV)
|
||||
}
|
||||
};
|
||||
if should_update_env {
|
||||
entry.api_key_env = Some(OLLAMA_API_KEY_ENV.to_string());
|
||||
}
|
||||
|
||||
if force
|
||||
|| entry
|
||||
.base_url
|
||||
.as_ref()
|
||||
.map(|value| value.trim().is_empty())
|
||||
.unwrap_or(true)
|
||||
{
|
||||
entry.base_url = Some(normalized.clone());
|
||||
}
|
||||
|
||||
if force {
|
||||
entry.enabled = true;
|
||||
}
|
||||
|
||||
entry.base_url != previous_base
|
||||
}
|
||||
|
||||
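/// Resolves the effective cloud endpoint, preferring the override stored in
/// `extra` over the provider's base URL.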
fn resolve_cloud_endpoint(cfg: &ProviderConfig) -> Option<String> {
|
||||
if let Some(value) = cfg
|
||||
.extra
|
||||
.get(CLOUD_ENDPOINT_KEY)
|
||||
.and_then(|value| value.as_str())
|
||||
.map(normalize_endpoint)
|
||||
{
|
||||
return Some(value);
|
||||
}
|
||||
|
||||
cfg.base_url
|
||||
.as_ref()
|
||||
.map(|value| value.trim_end_matches('/').to_string())
|
||||
.filter(|value| !value.is_empty())
|
||||
}
|
||||
|
||||
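/// Trims whitespace and trailing slashes from an endpoint, falling back to
/// the default cloud endpoint when the input is blank.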
fn normalize_endpoint(endpoint: &str) -> String {
|
||||
let trimmed = endpoint.trim().trim_end_matches('/');
|
||||
if trimmed.is_empty() {
|
||||
DEFAULT_CLOUD_ENDPOINT.to_string()
|
||||
} else {
|
||||
trimmed.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
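/// Lower-cases user-supplied provider names, swaps hyphens for underscores,
/// and maps the bare `ollama` alias onto the canonical `ollama_cloud` key.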
fn canonical_provider_name(provider: &str) -> String {
|
||||
let normalized = provider.trim().to_ascii_lowercase().replace('-', "_");
|
||||
match normalized.as_str() {
|
||||
"" => CLOUD_PROVIDER_KEY.to_string(),
|
||||
"ollama" => CLOUD_PROVIDER_KEY.to_string(),
|
||||
"ollama_cloud" => CLOUD_PROVIDER_KEY.to_string(),
|
||||
value => value.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn set_env_var<K, V>(key: K, value: V)
|
||||
where
|
||||
K: AsRef<OsStr>,
|
||||
V: AsRef<OsStr>,
|
||||
{
|
||||
// Safety: the CLI updates process-wide environment variables during startup while no
|
||||
// other threads are mutating the environment.
|
||||
unsafe {
|
||||
std::env::set_var(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
fn set_env_if_missing(var: &str, value: &str) {
|
||||
if std::env::var(var)
|
||||
.map(|v| v.trim().is_empty())
|
||||
.unwrap_or(true)
|
||||
{
|
||||
set_env_var(var, value);
|
||||
}
|
||||
}
|
||||
|
||||
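/// Opens the encrypted vault backing the credential store; fails fast when
/// encryption is disabled so callers never receive an unusable manager.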
fn unlock_credential_manager(
|
||||
config: &Config,
|
||||
storage: Arc<StorageManager>,
|
||||
) -> Result<Arc<CredentialManager>> {
|
||||
if !config.privacy.encrypt_local_data {
|
||||
bail!("Credential manager requested but encryption is disabled");
|
||||
}
|
||||
|
||||
let secure_path = vault_path(&storage)?;
|
||||
let handle = unlock_vault(&secure_path)?;
|
||||
let master_key = Arc::new(handle.data.master_key.clone());
|
||||
Ok(Arc::new(CredentialManager::new(
|
||||
storage,
|
||||
master_key,
|
||||
)))
|
||||
}
|
||||
|
||||
fn vault_path(storage: &StorageManager) -> Result<PathBuf> {
|
||||
let base_dir = storage
|
||||
.database_path()
|
||||
.parent()
|
||||
.map(|p| p.to_path_buf())
|
||||
.or_else(dirs::data_local_dir)
|
||||
.unwrap_or_else(|| PathBuf::from("."));
|
||||
Ok(base_dir.join("encrypted_data.json"))
|
||||
}
|
||||
|
||||
fn unlock_vault(path: &Path) -> Result<encryption::VaultHandle> {
|
||||
encryption::unlock(path.to_path_buf())
|
||||
}
|
||||
|
||||
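/// Loads the stored API key (credential vault first, then plaintext config),
/// exports it through the Ollama environment variables, and returns it.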
async fn hydrate_api_key(
|
||||
config: &mut Config,
|
||||
manager: Option<&Arc<CredentialManager>>,
|
||||
) -> Result<Option<String>> {
|
||||
let credentials = match manager {
|
||||
Some(manager) => manager.get_credentials(OLLAMA_CLOUD_CREDENTIAL_ID).await?,
|
||||
None => None,
|
||||
};
|
||||
|
||||
if let Some(credentials) = credentials {
|
||||
let key = credentials.api_key.trim().to_string();
|
||||
if !key.is_empty() {
|
||||
set_env_if_missing("OLLAMA_API_KEY", &key);
|
||||
set_env_if_missing("OLLAMA_CLOUD_API_KEY", &key);
|
||||
}
|
||||
|
||||
let cfg = core_config::ensure_provider_config_mut(config, CLOUD_PROVIDER_KEY);
|
||||
configure_cloud_endpoint(cfg, &credentials.endpoint, false);
|
||||
return Ok(Some(key));
|
||||
}
|
||||
|
||||
if let Some(key) = config
|
||||
.provider(CLOUD_PROVIDER_KEY)
|
||||
.and_then(|cfg| cfg.api_key.as_ref())
|
||||
.map(|value| value.trim())
|
||||
.filter(|value| !value.is_empty())
|
||||
{
|
||||
set_env_if_missing("OLLAMA_API_KEY", key);
|
||||
set_env_if_missing("OLLAMA_CLOUD_API_KEY", key);
|
||||
return Ok(Some(key.to_string()));
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub async fn load_runtime_credentials(
|
||||
config: &mut Config,
|
||||
storage: Arc<StorageManager>,
|
||||
) -> Result<()> {
|
||||
if config.privacy.encrypt_local_data {
|
||||
let manager = unlock_credential_manager(config, storage.clone())?;
|
||||
hydrate_api_key(config, Some(&manager)).await?;
|
||||
} else {
|
||||
hydrate_api_key(config, None).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn canonicalises_provider_names() {
|
||||
assert_eq!(canonical_provider_name("OLLAMA_CLOUD"), CLOUD_PROVIDER_KEY);
|
||||
assert_eq!(canonical_provider_name(" ollama-cloud"), CLOUD_PROVIDER_KEY);
|
||||
assert_eq!(canonical_provider_name(""), CLOUD_PROVIDER_KEY);
|
||||
}
|
||||
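    // A small extra check sketching `normalize_endpoint`'s contract, assuming
    // DEFAULT_CLOUD_ENDPOINT itself carries no trailing slash.
    #[test]
    fn normalizes_endpoints() {
        assert_eq!(
            normalize_endpoint("https://example.com/"),
            "https://example.com"
        );
        assert_eq!(normalize_endpoint("   "), DEFAULT_CLOUD_ENDPOINT);
    }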
}
|
||||
@@ -1,7 +0,0 @@
|
||||
//! Command implementations for the `owlen` CLI.
|
||||
|
||||
pub mod cloud;
|
||||
pub mod providers;
|
||||
pub mod repo;
|
||||
pub mod security;
|
||||
pub mod tools;
|
||||
@@ -1,800 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use clap::{Args, Subcommand};
|
||||
use owlen_core::ProviderConfig;
|
||||
use owlen_core::config::{self as core_config, Config};
|
||||
use owlen_core::provider::{
|
||||
AnnotatedModelInfo, ModelProvider, ProviderManager, ProviderStatus, ProviderType,
|
||||
};
|
||||
use owlen_core::storage::StorageManager;
|
||||
use owlen_core::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches};
|
||||
use owlen_providers::ollama::{OllamaCloudProvider, OllamaLocalProvider};
|
||||
use owlen_tui::config as tui_config;
|
||||
|
||||
use super::cloud;
|
||||
|
||||
/// CLI subcommands for provider management.
|
||||
#[derive(Debug, Subcommand)]
|
||||
pub enum ProvidersCommand {
|
||||
/// List configured providers and their metadata.
|
||||
List,
|
||||
/// Run health checks against providers.
|
||||
Status {
|
||||
/// Optional provider identifier to check.
|
||||
#[arg(value_name = "PROVIDER")]
|
||||
provider: Option<String>,
|
||||
},
|
||||
/// Enable a provider in the configuration.
|
||||
Enable {
|
||||
/// Provider identifier to enable.
|
||||
provider: String,
|
||||
},
|
||||
/// Disable a provider in the configuration.
|
||||
Disable {
|
||||
/// Provider identifier to disable.
|
||||
provider: String,
|
||||
},
|
||||
/// Enable or disable the `web_search` tool exposure.
|
||||
Web(WebCommand),
|
||||
}
|
||||
|
||||
/// Arguments for the `owlen models` command.
|
||||
#[derive(Debug, Default, Args)]
|
||||
pub struct ModelsArgs {
|
||||
/// Restrict output to a specific provider.
|
||||
#[arg(long)]
|
||||
pub provider: Option<String>,
|
||||
}
|
||||
|
||||
/// Arguments for managing the `web_search` tool exposure.
|
||||
#[derive(Debug, Args)]
|
||||
pub struct WebCommand {
|
||||
/// Enable the `web_search` tool and allow remote lookups.
|
||||
#[arg(long, conflicts_with = "disable")]
|
||||
enable: bool,
|
||||
/// Disable the `web_search` tool to keep sessions local-only.
|
||||
#[arg(long, conflicts_with = "enable")]
|
||||
disable: bool,
|
||||
}
|
||||
|
||||
impl WebCommand {
|
||||
fn desired_state(&self) -> Option<bool> {
|
||||
if self.enable {
|
||||
Some(true)
|
||||
} else if self.disable {
|
||||
Some(false)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_providers_command(command: ProvidersCommand) -> Result<()> {
|
||||
match command {
|
||||
ProvidersCommand::List => list_providers(),
|
||||
ProvidersCommand::Status { provider } => status_providers(provider.as_deref()).await,
|
||||
ProvidersCommand::Enable { provider } => toggle_provider(&provider, true),
|
||||
ProvidersCommand::Disable { provider } => toggle_provider(&provider, false),
|
||||
ProvidersCommand::Web(args) => handle_web_command(args),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_models_command(args: ModelsArgs) -> Result<()> {
|
||||
list_models(args.provider.as_deref()).await
|
||||
}
|
||||
|
||||
fn list_providers() -> Result<()> {
|
||||
let config = tui_config::try_load_config().unwrap_or_default();
|
||||
let default_provider = canonical_provider_id(&config.general.default_provider);
|
||||
|
||||
let mut rows = Vec::new();
|
||||
for (id, cfg) in &config.providers {
|
||||
let type_label = describe_provider_type(id, cfg);
|
||||
let auth_label = describe_auth(cfg, requires_auth(id, cfg));
|
||||
let enabled = if cfg.enabled { "yes" } else { "no" };
|
||||
let default = if id == &default_provider { "*" } else { "" };
|
||||
let base = cfg
|
||||
.base_url
|
||||
.as_ref()
|
||||
.map(|value| value.trim().to_string())
|
||||
.unwrap_or_else(|| "-".to_string());
|
||||
|
||||
rows.push(ProviderListRow {
|
||||
id: id.to_string(),
|
||||
type_label,
|
||||
enabled: enabled.to_string(),
|
||||
default: default.to_string(),
|
||||
auth: auth_label,
|
||||
base_url: base,
|
||||
});
|
||||
}
|
||||
|
||||
rows.sort_by(|a, b| a.id.cmp(&b.id));
|
||||
|
||||
let id_width = rows
|
||||
.iter()
|
||||
.map(|row| row.id.len())
|
||||
.max()
|
||||
.unwrap_or(8)
|
||||
.max("Provider".len());
|
||||
let enabled_width = rows
|
||||
.iter()
|
||||
.map(|row| row.enabled.len())
|
||||
.max()
|
||||
.unwrap_or(7)
|
||||
.max("Enabled".len());
|
||||
let default_width = rows
|
||||
.iter()
|
||||
.map(|row| row.default.len())
|
||||
.max()
|
||||
.unwrap_or(7)
|
||||
.max("Default".len());
|
||||
let type_width = rows
|
||||
.iter()
|
||||
.map(|row| row.type_label.len())
|
||||
.max()
|
||||
.unwrap_or(4)
|
||||
.max("Type".len());
|
||||
let auth_width = rows
|
||||
.iter()
|
||||
.map(|row| row.auth.len())
|
||||
.max()
|
||||
.unwrap_or(4)
|
||||
.max("Auth".len());
|
||||
|
||||
println!(
|
||||
"{:<id_width$} {:<enabled_width$} {:<default_width$} {:<type_width$} {:<auth_width$} Base URL",
|
||||
"Provider",
|
||||
"Enabled",
|
||||
"Default",
|
||||
"Type",
|
||||
"Auth",
|
||||
id_width = id_width,
|
||||
enabled_width = enabled_width,
|
||||
default_width = default_width,
|
||||
type_width = type_width,
|
||||
auth_width = auth_width,
|
||||
);
|
||||
|
||||
for row in rows {
|
||||
println!(
|
||||
"{:<id_width$} {:<enabled_width$} {:<default_width$} {:<type_width$} {:<auth_width$} {}",
|
||||
row.id,
|
||||
row.enabled,
|
||||
row.default,
|
||||
row.type_label,
|
||||
row.auth,
|
||||
row.base_url,
|
||||
id_width = id_width,
|
||||
enabled_width = enabled_width,
|
||||
default_width = default_width,
|
||||
type_width = type_width,
|
||||
auth_width = auth_width,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn status_providers(filter: Option<&str>) -> Result<()> {
|
||||
let mut config = tui_config::try_load_config().unwrap_or_default();
|
||||
let filter = filter.map(canonical_provider_id);
|
||||
verify_provider_filter(&config, filter.as_deref())?;
|
||||
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
cloud::load_runtime_credentials(&mut config, storage.clone()).await?;
|
||||
|
||||
let manager = ProviderManager::new(&config);
|
||||
let records = register_enabled_providers(&manager, &config, filter.as_deref()).await?;
|
||||
let health = manager.refresh_health().await;
|
||||
|
||||
let mut rows = Vec::new();
|
||||
for record in records {
|
||||
let status = health.get(&record.id).copied();
|
||||
rows.push(ProviderStatusRow::from_record(record, status));
|
||||
}
|
||||
|
||||
rows.sort_by(|a, b| a.id.cmp(&b.id));
|
||||
print_status_rows(&rows);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn list_models(filter: Option<&str>) -> Result<()> {
|
||||
let mut config = tui_config::try_load_config().unwrap_or_default();
|
||||
let filter = filter.map(canonical_provider_id);
|
||||
verify_provider_filter(&config, filter.as_deref())?;
|
||||
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
cloud::load_runtime_credentials(&mut config, storage.clone()).await?;
|
||||
|
||||
let manager = ProviderManager::new(&config);
|
||||
let records = register_enabled_providers(&manager, &config, filter.as_deref()).await?;
|
||||
let models = manager
|
||||
.list_all_models()
|
||||
.await
|
||||
.map_err(|err| anyhow!(err))?;
|
||||
let statuses = manager.provider_statuses().await;
|
||||
|
||||
print_models(records, models, statuses);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn verify_provider_filter(config: &Config, filter: Option<&str>) -> Result<()> {
|
||||
if let Some(filter) = filter
|
||||
&& !config.providers.contains_key(filter)
|
||||
{
|
||||
return Err(anyhow!(
|
||||
"Provider '{}' is not defined in configuration.",
|
||||
filter
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn handle_web_command(args: WebCommand) -> Result<()> {
|
||||
let mut config = tui_config::try_load_config().unwrap_or_default();
|
||||
let initial = web_tool_enabled(&config);
|
||||
|
||||
if let Some(desired) = args.desired_state() {
|
||||
apply_web_toggle(&mut config, desired);
|
||||
tui_config::save_config(&config).map_err(|err| anyhow!(err))?;
|
||||
|
||||
if initial == desired {
|
||||
println!(
|
||||
"Web search tool already {}.",
|
||||
if desired { "enabled" } else { "disabled" }
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
"Web search tool {}.",
|
||||
if desired { "enabled" } else { "disabled" }
|
||||
);
|
||||
}
|
||||
println!(
|
||||
"Remote search is {}.",
|
||||
if config.privacy.enable_remote_search {
|
||||
"enabled"
|
||||
} else {
|
||||
"disabled"
|
||||
}
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
"Web search tool is {}.",
|
||||
if initial { "enabled" } else { "disabled" }
|
||||
);
|
||||
println!(
|
||||
"Remote search is {}.",
|
||||
if config.privacy.enable_remote_search {
|
||||
"enabled"
|
||||
} else {
|
||||
"disabled"
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn apply_web_toggle(config: &mut Config, enabled: bool) {
|
||||
config.tools.web_search.enabled = enabled;
|
||||
config.privacy.enable_remote_search = enabled;
|
||||
|
||||
config
|
||||
.security
|
||||
.allowed_tools
|
||||
.retain(|tool| !tool_name_matches(tool, WEB_SEARCH_TOOL_NAME));
|
||||
|
||||
if enabled {
|
||||
config
|
||||
.security
|
||||
.allowed_tools
|
||||
.push(WEB_SEARCH_TOOL_NAME.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
fn web_tool_enabled(config: &Config) -> bool {
|
||||
config.tools.web_search.enabled && config.privacy.enable_remote_search
|
||||
}
|
||||
|
||||
fn toggle_provider(provider: &str, enable: bool) -> Result<()> {
|
||||
let mut config = tui_config::try_load_config().unwrap_or_default();
|
||||
let canonical = canonical_provider_id(provider);
|
||||
if canonical.is_empty() {
|
||||
return Err(anyhow!("Provider name cannot be empty."));
|
||||
}
|
||||
|
||||
let previous_default = config.general.default_provider.clone();
|
||||
let previous_fallback_enabled = config.providers.get("ollama_local").map(|cfg| cfg.enabled);
|
||||
|
||||
let previous_enabled;
|
||||
{
|
||||
let entry = core_config::ensure_provider_config_mut(&mut config, &canonical);
|
||||
previous_enabled = entry.enabled;
|
||||
if previous_enabled == enable {
|
||||
println!(
|
||||
"Provider '{}' is already {}.",
|
||||
canonical,
|
||||
if enable { "enabled" } else { "disabled" }
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
entry.enabled = enable;
|
||||
}
|
||||
|
||||
if !enable && config.general.default_provider == canonical {
|
||||
if let Some(candidate) = choose_fallback_provider(&config, &canonical) {
|
||||
config.general.default_provider = candidate.clone();
|
||||
println!(
|
||||
"Default provider set to '{}' because '{}' was disabled.",
|
||||
candidate, canonical
|
||||
);
|
||||
} else {
|
||||
let entry = core_config::ensure_provider_config_mut(&mut config, "ollama_local");
|
||||
entry.enabled = true;
|
||||
config.general.default_provider = "ollama_local".to_string();
|
||||
println!(
|
||||
"Enabled 'ollama_local' and made it default because no other providers are active."
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if let Err(err) = config.validate() {
|
||||
{
|
||||
let entry = core_config::ensure_provider_config_mut(&mut config, &canonical);
|
||||
entry.enabled = previous_enabled;
|
||||
}
|
||||
config.general.default_provider = previous_default;
|
||||
if let Some(enabled) = previous_fallback_enabled
|
||||
&& let Some(entry) = config.providers.get_mut("ollama_local")
|
||||
{
|
||||
entry.enabled = enabled;
|
||||
}
|
||||
return Err(anyhow!(err));
|
||||
}
|
||||
|
||||
tui_config::save_config(&config).map_err(|err| anyhow!(err))?;
|
||||
|
||||
println!(
|
||||
"{} provider '{}'.",
|
||||
if enable { "Enabled" } else { "Disabled" },
|
||||
canonical
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
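/// Picks a replacement default provider, preferring `ollama_local` and then
/// the alphabetically first remaining enabled provider.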
fn choose_fallback_provider(config: &Config, exclude: &str) -> Option<String> {
|
||||
if exclude != "ollama_local"
|
||||
&& let Some(cfg) = config.providers.get("ollama_local")
|
||||
&& cfg.enabled
|
||||
{
|
||||
return Some("ollama_local".to_string());
|
||||
}
|
||||
|
||||
let mut candidates: Vec<String> = config
|
||||
.providers
|
||||
.iter()
|
||||
.filter(|(id, cfg)| cfg.enabled && id.as_str() != exclude)
|
||||
.map(|(id, _)| id.clone())
|
||||
.collect();
|
||||
candidates.sort();
|
||||
candidates.into_iter().next()
|
||||
}
|
||||
|
||||
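/// Builds a record for every configured provider (honouring the optional
/// filter) and registers the enabled ones with the manager, capturing
/// instantiation errors per provider instead of aborting the whole listing.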
async fn register_enabled_providers(
|
||||
manager: &ProviderManager,
|
||||
config: &Config,
|
||||
filter: Option<&str>,
|
||||
) -> Result<Vec<ProviderRecord>> {
|
||||
let default_provider = canonical_provider_id(&config.general.default_provider);
|
||||
let mut records = Vec::new();
|
||||
|
||||
for (id, cfg) in &config.providers {
|
||||
if let Some(filter) = filter
|
||||
&& id != filter
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut record = ProviderRecord::from_config(id, cfg, id == &default_provider);
|
||||
if !cfg.enabled {
|
||||
records.push(record);
|
||||
continue;
|
||||
}
|
||||
|
||||
match instantiate_provider(id, cfg) {
|
||||
Ok(provider) => {
|
||||
let metadata = provider.metadata().clone();
|
||||
record.provider_type_label = provider_type_label(metadata.provider_type);
|
||||
record.requires_auth = metadata.requires_auth;
|
||||
record.metadata = Some(metadata);
|
||||
manager.register_provider(provider).await;
|
||||
}
|
||||
Err(err) => {
|
||||
record.registration_error = Some(err.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
records.push(record);
|
||||
}
|
||||
|
||||
records.sort_by(|a, b| a.id.cmp(&b.id));
|
||||
Ok(records)
|
||||
}
|
||||
|
||||
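/// Instantiates a concrete Ollama provider (local or cloud) from its
/// configuration entry; other provider types are reported as unsupported.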
fn instantiate_provider(id: &str, cfg: &ProviderConfig) -> Result<Arc<dyn ModelProvider>> {
|
||||
let kind = cfg.provider_type.trim().to_ascii_lowercase();
|
||||
if kind == "ollama" || id == "ollama_local" {
|
||||
let provider = OllamaLocalProvider::new(cfg.base_url.clone(), None, None)
|
||||
.map_err(|err| anyhow!(err))?;
|
||||
Ok(Arc::new(provider))
|
||||
} else if kind == "ollama_cloud" || id == "ollama_cloud" {
|
||||
let provider = OllamaCloudProvider::new(cfg.base_url.clone(), cfg.api_key.clone(), None)
|
||||
.map_err(|err| anyhow!(err))?;
|
||||
Ok(Arc::new(provider))
|
||||
} else {
|
||||
Err(anyhow!(
|
||||
"Provider '{}' uses unsupported type '{}'.",
|
||||
id,
|
||||
if kind.is_empty() {
|
||||
"unknown"
|
||||
} else {
|
||||
kind.as_str()
|
||||
}
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
fn describe_provider_type(id: &str, cfg: &ProviderConfig) -> String {
|
||||
if cfg.provider_type.trim().eq_ignore_ascii_case("ollama") || id.ends_with("_local") {
|
||||
"Local".to_string()
|
||||
} else if cfg
|
||||
.provider_type
|
||||
.trim()
|
||||
.eq_ignore_ascii_case("ollama_cloud")
|
||||
|| id.contains("cloud")
|
||||
{
|
||||
"Cloud".to_string()
|
||||
} else {
|
||||
"Custom".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn requires_auth(id: &str, cfg: &ProviderConfig) -> bool {
|
||||
cfg.api_key.is_some()
|
||||
|| cfg.api_key_env.is_some()
|
||||
|| matches!(id, "ollama_cloud" | "openai" | "anthropic")
|
||||
}
|
||||
|
||||
fn describe_auth(cfg: &ProviderConfig, required: bool) -> String {
|
||||
if let Some(env) = cfg
|
||||
.api_key_env
|
||||
.as_ref()
|
||||
.map(|value| value.trim())
|
||||
.filter(|value| !value.is_empty())
|
||||
{
|
||||
format!("env:{env}")
|
||||
} else if cfg
|
||||
.api_key
|
||||
.as_ref()
|
||||
.map(|value| !value.trim().is_empty())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
"config".to_string()
|
||||
} else if required {
|
||||
"required".to_string()
|
||||
} else {
|
||||
"-".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn canonical_provider_id(raw: &str) -> String {
|
||||
let trimmed = raw.trim().to_ascii_lowercase();
|
||||
if trimmed.is_empty() {
|
||||
return trimmed;
|
||||
}
|
||||
|
||||
match trimmed.as_str() {
|
||||
"ollama" | "ollama-local" => "ollama_local".to_string(),
|
||||
"ollama_cloud" | "ollama-cloud" => "ollama_cloud".to_string(),
|
||||
other => other.replace('-', "_"),
|
||||
}
|
||||
}
|
||||
|
||||
fn provider_type_label(provider_type: ProviderType) -> String {
|
||||
match provider_type {
|
||||
ProviderType::Local => "Local".to_string(),
|
||||
ProviderType::Cloud => "Cloud".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn provider_status_strings(status: ProviderStatus) -> (&'static str, &'static str) {
|
||||
match status {
|
||||
ProviderStatus::Available => ("OK", "available"),
|
||||
ProviderStatus::Unavailable => ("ERR", "unavailable"),
|
||||
ProviderStatus::RequiresSetup => ("SETUP", "requires setup"),
|
||||
}
|
||||
}
|
||||
|
||||
fn print_status_rows(rows: &[ProviderStatusRow]) {
|
||||
let id_width = rows
|
||||
.iter()
|
||||
.map(|row| row.id.len())
|
||||
.max()
|
||||
.unwrap_or(8)
|
||||
.max("Provider".len());
|
||||
let type_width = rows
|
||||
.iter()
|
||||
.map(|row| row.provider_type.len())
|
||||
.max()
|
||||
.unwrap_or(4)
|
||||
.max("Type".len());
|
||||
let status_width = rows
|
||||
.iter()
|
||||
.map(|row| row.indicator.len() + 1 + row.status_label.len())
|
||||
.max()
|
||||
.unwrap_or(6)
|
||||
.max("State".len());
|
||||
|
||||
println!(
|
||||
"{:<id_width$} {:<4} {:<type_width$} {:<status_width$} Details",
|
||||
"Provider",
|
||||
"Def",
|
||||
"Type",
|
||||
"State",
|
||||
id_width = id_width,
|
||||
type_width = type_width,
|
||||
status_width = status_width,
|
||||
);
|
||||
|
||||
for row in rows {
|
||||
let def = if row.default_provider { "*" } else { "-" };
|
||||
let details = row.detail.as_deref().unwrap_or("-");
|
||||
println!(
|
||||
"{:<id_width$} {:<4} {:<type_width$} {:<status_width$} {}",
|
||||
row.id,
|
||||
def,
|
||||
row.provider_type,
|
||||
format!("{} {}", row.indicator, row.status_label),
|
||||
details,
|
||||
id_width = id_width,
|
||||
type_width = type_width,
|
||||
status_width = status_width,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn print_models(
|
||||
records: Vec<ProviderRecord>,
|
||||
models: Vec<AnnotatedModelInfo>,
|
||||
statuses: HashMap<String, ProviderStatus>,
|
||||
) {
|
||||
let mut grouped: HashMap<String, Vec<AnnotatedModelInfo>> = HashMap::new();
|
||||
for info in models {
|
||||
grouped
|
||||
.entry(info.provider_id.clone())
|
||||
.or_default()
|
||||
.push(info);
|
||||
}
|
||||
|
||||
for record in records {
|
||||
let status = statuses.get(&record.id).copied().or_else(|| {
|
||||
if record.metadata.is_some() && record.registration_error.is_none() && record.enabled {
|
||||
Some(ProviderStatus::Unavailable)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
});
|
||||
|
||||
let (indicator, label, status_value) = if !record.enabled {
|
||||
("-", "disabled", None)
|
||||
} else if record.registration_error.is_some() {
|
||||
("ERR", "error", None)
|
||||
} else if let Some(status) = status {
|
||||
let (indicator, label) = provider_status_strings(status);
|
||||
(indicator, label, Some(status))
|
||||
} else {
|
||||
("?", "unknown", None)
|
||||
};
|
||||
|
||||
let title = if record.default_provider {
|
||||
format!("{} (default)", record.id)
|
||||
} else {
|
||||
record.id.clone()
|
||||
};
|
||||
println!(
|
||||
"{} {} [{}] {}",
|
||||
indicator, title, record.provider_type_label, label
|
||||
);
|
||||
|
||||
if let Some(err) = &record.registration_error {
|
||||
println!(" error: {}", err);
|
||||
println!();
|
||||
continue;
|
||||
}
|
||||
|
||||
if !record.enabled {
|
||||
println!(" provider disabled");
|
||||
println!();
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(entries) = grouped.get(&record.id) {
|
||||
let mut entries = entries.clone();
|
||||
entries.sort_by(|a, b| a.model.name.cmp(&b.model.name));
|
||||
if entries.is_empty() {
|
||||
println!(" (no models reported)");
|
||||
} else {
|
||||
for entry in entries {
|
||||
let mut line = format!(" - {}", entry.model.name);
|
||||
if let Some(description) = &entry.model.description
|
||||
&& !description.trim().is_empty()
|
||||
{
|
||||
line.push_str(&format!(" — {}", description.trim()));
|
||||
}
|
||||
println!("{}", line);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
println!(" (no models reported)");
|
||||
}
|
||||
|
||||
if let Some(ProviderStatus::RequiresSetup) = status_value
|
||||
&& record.requires_auth
|
||||
{
|
||||
println!(" configure provider credentials or API key");
|
||||
}
|
||||
println!();
|
||||
}
|
||||
}
|
||||
|
||||
struct ProviderListRow {
|
||||
id: String,
|
||||
type_label: String,
|
||||
enabled: String,
|
||||
default: String,
|
||||
auth: String,
|
||||
base_url: String,
|
||||
}
|
||||
|
||||
struct ProviderRecord {
|
||||
id: String,
|
||||
enabled: bool,
|
||||
default_provider: bool,
|
||||
provider_type_label: String,
|
||||
requires_auth: bool,
|
||||
registration_error: Option<String>,
|
||||
metadata: Option<owlen_core::provider::ProviderMetadata>,
|
||||
}
|
||||
|
||||
impl ProviderRecord {
|
||||
fn from_config(id: &str, cfg: &ProviderConfig, default_provider: bool) -> Self {
|
||||
Self {
|
||||
id: id.to_string(),
|
||||
enabled: cfg.enabled,
|
||||
default_provider,
|
||||
provider_type_label: describe_provider_type(id, cfg),
|
||||
requires_auth: requires_auth(id, cfg),
|
||||
registration_error: None,
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ProviderStatusRow {
|
||||
id: String,
|
||||
provider_type: String,
|
||||
default_provider: bool,
|
||||
indicator: String,
|
||||
status_label: String,
|
||||
detail: Option<String>,
|
||||
}
|
||||
|
||||
impl ProviderStatusRow {
|
||||
fn from_record(record: ProviderRecord, status: Option<ProviderStatus>) -> Self {
|
||||
if !record.enabled {
|
||||
return Self {
|
||||
id: record.id,
|
||||
provider_type: record.provider_type_label,
|
||||
default_provider: record.default_provider,
|
||||
indicator: "-".to_string(),
|
||||
status_label: "disabled".to_string(),
|
||||
detail: None,
|
||||
};
|
||||
}
|
||||
|
||||
if let Some(err) = record.registration_error {
|
||||
return Self {
|
||||
id: record.id,
|
||||
provider_type: record.provider_type_label,
|
||||
default_provider: record.default_provider,
|
||||
indicator: "ERR".to_string(),
|
||||
status_label: "error".to_string(),
|
||||
detail: Some(err),
|
||||
};
|
||||
}
|
||||
|
||||
if let Some(status) = status {
|
||||
let (indicator, label) = provider_status_strings(status);
|
||||
return Self {
|
||||
id: record.id,
|
||||
provider_type: record.provider_type_label,
|
||||
default_provider: record.default_provider,
|
||||
indicator: indicator.to_string(),
|
||||
status_label: label.to_string(),
|
||||
detail: if matches!(status, ProviderStatus::RequiresSetup) && record.requires_auth {
|
||||
Some("credentials required".to_string())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
};
|
||||
}
|
||||
|
||||
Self {
|
||||
id: record.id,
|
||||
provider_type: record.provider_type_label,
|
||||
default_provider: record.default_provider,
|
||||
indicator: "?".to_string(),
|
||||
status_label: "unknown".to_string(),
|
||||
detail: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn apply_web_toggle_updates_flags_and_allowed_tools() {
|
||||
let mut config = Config::default();
|
||||
config.privacy.enable_remote_search = false;
|
||||
config.tools.web_search.enabled = false;
|
||||
config.security.allowed_tools.clear();
|
||||
|
||||
apply_web_toggle(&mut config, true);
|
||||
assert!(config.tools.web_search.enabled);
|
||||
assert!(config.privacy.enable_remote_search);
|
||||
assert_eq!(
|
||||
1,
|
||||
config
|
||||
.security
|
||||
.allowed_tools
|
||||
.iter()
|
||||
.filter(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME))
|
||||
.count()
|
||||
);
|
||||
|
||||
apply_web_toggle(&mut config, false);
|
||||
assert!(!config.tools.web_search.enabled);
|
||||
assert!(!config.privacy.enable_remote_search);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_web_toggle_does_not_duplicate_allowed_entries() {
|
||||
let mut config = Config::default();
|
||||
config
|
||||
.security
|
||||
.allowed_tools
|
||||
.retain(|tool| !tool_name_matches(tool, WEB_SEARCH_TOOL_NAME));
|
||||
config
|
||||
.security
|
||||
.allowed_tools
|
||||
.push(WEB_SEARCH_TOOL_NAME.to_string());
|
||||
|
||||
apply_web_toggle(&mut config, true);
|
||||
apply_web_toggle(&mut config, true);
|
||||
|
||||
assert_eq!(
|
||||
1,
|
||||
config
|
||||
.security
|
||||
.allowed_tools
|
||||
.iter()
|
||||
.filter(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME))
|
||||
.count()
|
||||
);
|
||||
}
|
||||
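    // A small extra check sketching `canonical_provider_id`'s normalisation:
    // hyphens become underscores and the bare alias maps to the local entry.
    #[test]
    fn canonicalises_provider_ids() {
        assert_eq!(canonical_provider_id("Ollama-Local"), "ollama_local");
        assert_eq!(canonical_provider_id("my-provider"), "my_provider");
    }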
}
|
||||
@@ -1,203 +0,0 @@
|
||||
use std::env;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use clap::{Args, Subcommand, ValueEnum};
|
||||
use owlen_core::automation::repo::{
|
||||
CommitTemplate, DiffCaptureMode, PullRequestContext, PullRequestReview, RepoAutomation,
|
||||
summarize_diff,
|
||||
};
|
||||
use owlen_core::github::{GithubClient, GithubConfig};
|
||||
|
||||
/// Subcommands for repository automation helpers (commit templates, PR reviews, workflows).
|
||||
#[allow(clippy::large_enum_variant)]
|
||||
#[derive(Debug, Subcommand)]
|
||||
pub enum RepoCommand {
|
||||
/// Generate a conventional commit template from repository changes.
|
||||
CommitTemplate(CommitTemplateArgs),
|
||||
/// Produce a structured review for a pull request or local diff.
|
||||
Review(ReviewArgs),
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
pub struct CommitTemplateArgs {
|
||||
/// Repository path (defaults to current directory).
|
||||
#[arg(long, value_name = "PATH")]
|
||||
pub repo: Option<PathBuf>,
|
||||
/// Output format for the generated template.
|
||||
#[arg(long, value_enum, default_value_t = OutputFormat::Markdown)]
|
||||
pub format: OutputFormat,
|
||||
/// Include unstaged working tree changes instead of staged changes.
|
||||
#[arg(long)]
|
||||
pub working_tree: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
pub struct ReviewArgs {
|
||||
/// Repository path for local diff analysis (`--repo` is taken by the GitHub
/// repository name below, so this flag is `--repo-path`).
#[arg(long = "repo-path", value_name = "PATH")]
|
||||
pub repo: Option<PathBuf>,
|
||||
/// Base ref for local diff review (default: origin/main).
|
||||
#[arg(long)]
|
||||
pub base: Option<String>,
|
||||
/// Head ref for local diff review (default: HEAD).
|
||||
#[arg(long)]
|
||||
pub head: Option<String>,
|
||||
/// Owner of the GitHub repository.
|
||||
#[arg(long)]
|
||||
pub owner: Option<String>,
|
||||
/// Repository name on GitHub.
|
||||
#[arg(long = "repo")]
|
||||
pub repository: Option<String>,
|
||||
/// Pull request number to fetch from GitHub.
|
||||
#[arg(long)]
|
||||
pub number: Option<u64>,
|
||||
/// GitHub personal access token (falls back to environment variable).
|
||||
#[arg(long)]
|
||||
pub token: Option<String>,
|
||||
/// Environment variable used to resolve the GitHub token.
|
||||
#[arg(long, default_value = "GITHUB_TOKEN")]
|
||||
pub token_env: String,
|
||||
/// Custom GitHub API endpoint (for GitHub Enterprise).
|
||||
#[arg(long)]
|
||||
pub api_endpoint: Option<String>,
|
||||
/// Path to a diff file to analyse instead of hitting Git or GitHub.
|
||||
#[arg(long, value_name = "FILE")]
|
||||
pub diff_file: Option<PathBuf>,
|
||||
/// Output format for the review body.
|
||||
#[arg(long, value_enum, default_value_t = OutputFormat::Markdown)]
|
||||
pub format: OutputFormat,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, ValueEnum, PartialEq, Eq)]
|
||||
pub enum OutputFormat {
|
||||
Text,
|
||||
Markdown,
|
||||
Json,
|
||||
}
|
||||
|
||||
pub async fn run_repo_command(command: RepoCommand) -> Result<()> {
|
||||
match command {
|
||||
RepoCommand::CommitTemplate(args) => handle_commit_template(args).await,
|
||||
RepoCommand::Review(args) => handle_review(args).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_commit_template(args: CommitTemplateArgs) -> Result<()> {
|
||||
let repo_hint = args.repo.clone().unwrap_or_else(|| PathBuf::from("."));
|
||||
let automation = RepoAutomation::from_path(&repo_hint)?;
|
||||
let mode = if args.working_tree {
|
||||
DiffCaptureMode::WorkingTree
|
||||
} else {
|
||||
DiffCaptureMode::Staged
|
||||
};
|
||||
let template = automation.generate_commit_template(mode)?;
|
||||
emit_commit_template(&template, args.format);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
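/// Routes a review request: a PR number goes through the GitHub API, an
/// explicit diff file is analysed directly, and otherwise the local
/// repository diff between `--base` and `--head` is reviewed.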
async fn handle_review(args: ReviewArgs) -> Result<()> {
|
||||
if let Some(number) = args.number {
|
||||
let owner = args
|
||||
.owner
|
||||
.as_deref()
|
||||
.ok_or_else(|| anyhow!("--owner is required when --number is provided"))?;
|
||||
let repo = args
|
||||
.repository
|
||||
.as_deref()
|
||||
.ok_or_else(|| anyhow!("--repo is required when --number is provided"))?;
|
||||
let token = args
|
||||
.token
|
||||
.or_else(|| env::var(&args.token_env).ok())
|
||||
.filter(|value| !value.trim().is_empty());
|
||||
let client = GithubClient::new(GithubConfig {
|
||||
token,
|
||||
api_endpoint: args.api_endpoint.clone(),
|
||||
})?;
|
||||
let details = client.pull_request(owner, repo, number).await?;
|
||||
let review = PullRequestReview::from_diff(details.context, &details.diff);
|
||||
emit_review_output(review, args.format);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if let Some(path) = args.diff_file.as_ref() {
|
||||
let diff = std::fs::read_to_string(path)
|
||||
.with_context(|| format!("Failed to read diff file {}", path.display()))?;
|
||||
let stats = summarize_diff(&diff);
|
||||
let diff_label = path
|
||||
.file_name()
|
||||
.and_then(|s| s.to_str())
|
||||
.map(|s| s.to_string())
|
||||
.unwrap_or_else(|| path.display().to_string());
|
||||
let context = PullRequestContext {
|
||||
title: format!("Review for diff from {}", diff_label),
|
||||
body: None,
|
||||
author: None,
|
||||
base_branch: args
|
||||
.base
|
||||
.clone()
|
||||
.unwrap_or_else(|| "(unknown base)".to_string()),
|
||||
head_branch: args.head.clone().unwrap_or_else(|| "(diff)".to_string()),
|
||||
additions: stats.additions as u64,
|
||||
deletions: stats.deletions as u64,
|
||||
changed_files: stats.files as u64,
|
||||
html_url: None,
|
||||
};
|
||||
let review = PullRequestReview::from_diff(context, &diff);
|
||||
emit_review_output(review, args.format);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let repo_hint = args.repo.clone().unwrap_or_else(|| PathBuf::from("."));
|
||||
let automation = RepoAutomation::from_path(&repo_hint)?;
|
||||
let review = automation.generate_pr_review(args.base.as_deref(), args.head.as_deref())?;
|
||||
emit_review_output(review, args.format);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn emit_commit_template(template: &CommitTemplate, format: OutputFormat) {
|
||||
match format {
|
||||
OutputFormat::Markdown => {
|
||||
println!("{}", template.render_markdown());
|
||||
}
|
||||
OutputFormat::Text => {
|
||||
let markdown = template.render_markdown();
|
||||
for line in markdown.lines() {
|
||||
println!("{}", line.trim_start_matches('-').trim());
|
||||
}
|
||||
}
|
||||
OutputFormat::Json => match serde_json::to_string_pretty(template) {
|
||||
Ok(json) => println!("{}", json),
|
||||
Err(err) => eprintln!("Failed to encode template as JSON: {}", err),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn emit_review_output(review: PullRequestReview, format: OutputFormat) {
|
||||
match format {
|
||||
OutputFormat::Markdown => println!("{}", review.render_markdown()),
|
||||
OutputFormat::Text => {
|
||||
println!("{}", review.summary);
|
||||
for highlight in review.highlights {
|
||||
println!("* {}", highlight);
|
||||
}
|
||||
if !review.findings.is_empty() {
|
||||
println!("Findings:");
|
||||
for finding in review.findings {
|
||||
println!(" - [{}] {}", finding.severity, finding.message);
|
||||
}
|
||||
}
|
||||
if !review.checklist.is_empty() {
|
||||
println!("Checklist:");
|
||||
for item in review.checklist {
|
||||
let mark = if item.completed { "x" } else { " " };
|
||||
println!(" - [{}] {}", mark, item.label);
|
||||
}
|
||||
}
|
||||
}
|
||||
OutputFormat::Json => match serde_json::to_string_pretty(&review) {
|
||||
Ok(json) => println!("{}", json),
|
||||
Err(err) => eprintln!("Failed to encode review as JSON: {}", err),
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use clap::{Subcommand, ValueEnum};
|
||||
use owlen_core::config::ApprovalMode;
|
||||
use owlen_tui::config as tui_config;
|
||||
|
||||
/// Security-related configuration commands.
|
||||
#[derive(Debug, Subcommand)]
|
||||
pub enum SecurityCommand {
|
||||
/// Display the current approval mode.
|
||||
Show,
|
||||
/// Set the approval mode (auto, read-only, plan-first).
|
||||
Approval {
|
||||
/// Approval policy to apply to new sessions.
|
||||
#[arg(value_enum)]
|
||||
mode: ApprovalModeArg,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, ValueEnum)]
|
||||
pub enum ApprovalModeArg {
|
||||
Auto,
|
||||
#[clap(name = "read-only")]
|
||||
ReadOnly,
|
||||
#[clap(name = "plan-first")]
|
||||
PlanFirst,
|
||||
}
|
||||
|
||||
impl From<ApprovalModeArg> for ApprovalMode {
|
||||
fn from(value: ApprovalModeArg) -> Self {
|
||||
match value {
|
||||
ApprovalModeArg::Auto => ApprovalMode::Auto,
|
||||
ApprovalModeArg::ReadOnly => ApprovalMode::ReadOnly,
|
||||
ApprovalModeArg::PlanFirst => ApprovalMode::PlanFirst,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run_security_command(command: SecurityCommand) -> Result<()> {
|
||||
match command {
|
||||
SecurityCommand::Show => show_approval_mode(),
|
||||
SecurityCommand::Approval { mode } => set_approval_mode(mode.into()),
|
||||
}
|
||||
}
|
||||
|
||||
fn show_approval_mode() -> Result<()> {
|
||||
let config = tui_config::try_load_config().unwrap_or_default();
|
||||
println!(
|
||||
"Current approval mode: {}",
|
||||
config.security.approval_mode.as_str()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_approval_mode(mode: ApprovalMode) -> Result<()> {
|
||||
let mut config = tui_config::try_load_config().unwrap_or_default();
|
||||
config.security.approval_mode = mode;
|
||||
config.validate()?;
|
||||
tui_config::save_config(&config)?;
|
||||
println!("Set approval mode to {}.", mode.as_str());
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,110 +0,0 @@
use std::str::FromStr;

use anyhow::{Result, anyhow};
use clap::{Args, Subcommand};
use owlen_core::mcp::presets::{self, PresetTier};
use owlen_tui::config as tui_config;

/// CLI entry points for managing MCP tool presets.
#[derive(Debug, Subcommand)]
pub enum ToolsCommand {
    /// Install a reference MCP tool preset.
    Install(InstallArgs),
    /// Audit the current MCP servers against a preset.
    Audit(AuditArgs),
}

#[derive(Debug, Args)]
pub struct InstallArgs {
    /// Preset tier to install (standard, extended, full).
    #[arg(value_parser = parse_preset)]
    pub preset: PresetTier,
    /// Remove MCP servers not included in the preset.
    #[arg(long)]
    pub prune: bool,
}

#[derive(Debug, Args)]
pub struct AuditArgs {
    /// Preset tier to audit (defaults to full).
    #[arg(value_parser = parse_preset)]
    pub preset: Option<PresetTier>,
}

pub fn run_tools_command(command: ToolsCommand) -> Result<()> {
    match command {
        ToolsCommand::Install(args) => install_preset(args),
        ToolsCommand::Audit(args) => audit_preset(args),
    }
}

fn install_preset(args: InstallArgs) -> Result<()> {
    let mut config = tui_config::try_load_config().unwrap_or_default();
    let report = presets::apply_preset(&mut config, args.preset, args.prune)?;
    tui_config::save_config(&config)?;

    println!(
        "Installed '{}' preset (prune = {}).",
        args.preset.as_str(),
        args.prune
    );

    if !report.added.is_empty() {
        println!(" added: {}", report.added.join(", "));
    }
    if !report.updated.is_empty() {
        println!(" updated: {}", report.updated.join(", "));
    }
    if !report.removed.is_empty() {
        println!(" removed: {}", report.removed.join(", "));
    }

    if report.added.is_empty() && report.updated.is_empty() && report.removed.is_empty() {
        println!(" no changes were necessary.");
    }

    Ok(())
}

fn audit_preset(args: AuditArgs) -> Result<()> {
    let config = tui_config::try_load_config().unwrap_or_default();
    let preset = args.preset.unwrap_or(PresetTier::Full);
    let report = presets::audit_preset(&config, preset);

    println!("Audit for '{}' preset:", preset.as_str());
    if report.missing.is_empty() && report.mismatched.is_empty() && report.extra.is_empty() {
        println!(" configuration already matches this preset.");
        return Ok(());
    }

    if !report.missing.is_empty() {
        println!(" missing connectors:");
        for missing in report.missing {
            println!("   - {}", missing.name);
        }
    }

    if !report.mismatched.is_empty() {
        println!(" mismatched connectors:");
        for (expected, actual) in report.mismatched {
            println!(
                "   - {} (expected command '{}', found '{}')",
                expected.name, expected.command, actual.command
            );
        }
    }

    if !report.extra.is_empty() {
        println!(" extra connectors:");
        for extra in report.extra {
            println!("   - {}", extra.name);
        }
    }

    Ok(())
}

fn parse_preset(value: &str) -> Result<PresetTier> {
    PresetTier::from_str(value)
        .map_err(|_| anyhow!("Unknown preset '{value}'. Use one of: standard, extended, full."))
}
@@ -1,8 +0,0 @@
//! Library portion of the `owlen-cli` crate.
//!
//! It currently only re‑exports the `agent` module used by the standalone
//! `owlen-agent` binary. Additional shared functionality can be added here in
//! the future.

// Re-export agent module from owlen-core
pub use owlen_core::agent;
@@ -1,489 +0,0 @@
//! OWLEN CLI - Chat TUI client

mod bootstrap;
mod commands;
mod mcp;

use anyhow::{Result, anyhow};
use clap::{Parser, Subcommand};
use commands::{
    cloud::{CloudCommand, run_cloud_command},
    providers::{ModelsArgs, ProvidersCommand, run_models_command, run_providers_command},
    repo::{RepoCommand, run_repo_command},
    security::{SecurityCommand, run_security_command},
    tools::{ToolsCommand, run_tools_command},
};
use mcp::{McpCommand, run_mcp_command};
use owlen_core::config::{
    self as core_config, Config, DEFAULT_OLLAMA_CLOUD_HOURLY_QUOTA,
    DEFAULT_OLLAMA_CLOUD_WEEKLY_QUOTA, DEFAULT_PROVIDER_CONTEXT_WINDOW_TOKENS,
    DEFAULT_PROVIDER_LIST_TTL_SECS, LEGACY_OLLAMA_CLOUD_API_KEY_ENV, LEGACY_OLLAMA_CLOUD_BASE_URL,
    LEGACY_OWLEN_OLLAMA_CLOUD_API_KEY_ENV, McpMode, OLLAMA_API_KEY_ENV, OLLAMA_CLOUD_BASE_URL,
    OLLAMA_CLOUD_ENDPOINT_KEY,
};
use owlen_core::mode::Mode;
use owlen_tui::config;
use serde_json::{Number as JsonNumber, Value as JsonValue};
use std::env;

/// Owlen - Terminal UI for LLM chat
#[derive(Parser, Debug)]
#[command(name = "owlen")]
#[command(about = "Terminal UI for LLM chat via MCP", long_about = None)]
struct Args {
    /// Start in code mode (enables all tools)
    #[arg(long, short = 'c')]
    code: bool,
    /// Disable automatic transcript compression for this session
    #[arg(long)]
    no_auto_compress: bool,
    #[command(subcommand)]
    command: Option<OwlenCommand>,
}

#[derive(Debug, Subcommand)]
enum OwlenCommand {
    /// Inspect or upgrade configuration files
    #[command(subcommand)]
    Config(ConfigCommand),
    /// Manage Ollama Cloud credentials
    #[command(subcommand)]
    Cloud(CloudCommand),
    /// Manage model providers
    #[command(subcommand)]
    Providers(ProvidersCommand),
    /// List models exposed by configured providers
    Models(ModelsArgs),
    /// Manage MCP server registrations
    #[command(subcommand)]
    Mcp(McpCommand),
    /// Manage MCP tool presets
    #[command(subcommand)]
    Tools(ToolsCommand),
    /// Configure security and approval policies
    #[command(subcommand)]
    Security(SecurityCommand),
    /// Repository automation helpers (commit templates, PR reviews)
    #[command(subcommand)]
    Repo(RepoCommand),
    /// Show manual steps for updating Owlen to the latest revision
    Upgrade,
}

#[derive(Debug, Subcommand)]
enum ConfigCommand {
    /// Automatically upgrade legacy configuration values and ensure validity
    Doctor,
    /// Print the resolved configuration file path
    Path,
    /// Create a fresh configuration file using the latest defaults
    Init {
        /// Overwrite the existing configuration if present.
        #[arg(long)]
        force: bool,
    },
}

async fn run_command(command: OwlenCommand) -> Result<()> {
    match command {
        OwlenCommand::Config(config_cmd) => run_config_command(config_cmd),
        OwlenCommand::Cloud(cloud_cmd) => run_cloud_command(cloud_cmd).await,
        OwlenCommand::Providers(provider_cmd) => run_providers_command(provider_cmd).await,
        OwlenCommand::Models(args) => run_models_command(args).await,
        OwlenCommand::Mcp(mcp_cmd) => run_mcp_command(mcp_cmd),
        OwlenCommand::Tools(tools_cmd) => run_tools_command(tools_cmd),
        OwlenCommand::Security(sec_cmd) => run_security_command(sec_cmd),
        OwlenCommand::Repo(repo_cmd) => run_repo_command(repo_cmd).await,
        OwlenCommand::Upgrade => {
            println!(
                "To update Owlen from source:\n git pull\n cargo install --path crates/owlen-cli --force"
            );
            println!(
                "If you installed from the AUR, use your package manager (e.g., yay -S owlen-git)."
            );
            Ok(())
        }
    }
}

fn run_config_command(command: ConfigCommand) -> Result<()> {
    match command {
        ConfigCommand::Doctor => run_config_doctor(),
        ConfigCommand::Path => {
            let path = core_config::default_config_path();
            println!("{}", path.display());
            Ok(())
        }
        ConfigCommand::Init { force } => run_config_init(force),
    }
}

fn run_config_init(force: bool) -> Result<()> {
    let config_path = core_config::default_config_path();
    if config_path.exists() && !force {
        return Err(anyhow!(
            "Configuration already exists at {}. Re-run with --force to overwrite.",
            config_path.display()
        ));
    }

    let mut config = Config::default();
    let _ = config.refresh_mcp_servers(None);
    config.validate()?;

    config::save_config(&config)?;
    println!("Wrote default configuration to {}.", config_path.display());
    Ok(())
}

fn run_config_doctor() -> Result<()> {
    let config_path = core_config::default_config_path();
    let existed = config_path.exists();
    let mut config = config::try_load_config().unwrap_or_default();
    let _ = config.refresh_mcp_servers(None);
    let mut changes = Vec::new();
    let mut warnings = Vec::new();

    if !existed {
        changes.push("created configuration file from defaults".to_string());
    }

    if config.provider(&config.general.default_provider).is_none() {
        config.general.default_provider = "ollama_local".to_string();
        changes.push("default provider missing; reset to 'ollama_local'".to_string());
    }

    for key in ["ollama_local", "ollama_cloud", "openai", "anthropic"] {
        if !config.providers.contains_key(key) {
            core_config::ensure_provider_config_mut(&mut config, key);
            changes.push(format!("added default configuration for provider '{key}'"));
        }
    }

    if let Some(local) = config.providers.get_mut("ollama_local") {
        if ensure_numeric_extra_with_change(
            &mut local.extra,
            "list_ttl_secs",
            DEFAULT_PROVIDER_LIST_TTL_SECS,
        ) {
            changes.push("added providers.ollama_local.list_ttl_secs (default 60)".to_string());
        }
        if ensure_numeric_extra_with_change(
            &mut local.extra,
            "default_context_window",
            u64::from(DEFAULT_PROVIDER_CONTEXT_WINDOW_TOKENS),
        ) {
            changes.push(format!(
                "added providers.ollama_local.default_context_window (default {})",
                DEFAULT_PROVIDER_CONTEXT_WINDOW_TOKENS
            ));
        }
        if local.provider_type.trim().is_empty() || local.provider_type != "ollama" {
            local.provider_type = "ollama".to_string();
            changes.push("normalised providers.ollama_local.provider_type to 'ollama'".to_string());
        }
    }

    if let Some(cloud) = config.providers.get_mut("ollama_cloud") {
        if cloud.provider_type.trim().is_empty()
            || !cloud.provider_type.eq_ignore_ascii_case("ollama_cloud")
        {
            cloud.provider_type = "ollama_cloud".to_string();
            changes.push(
                "normalised providers.ollama_cloud.provider_type to 'ollama_cloud'".to_string(),
            );
        }

        let previous_base_url = cloud.base_url.clone();
        match cloud
            .base_url
            .as_ref()
            .map(|value| value.trim_end_matches('/'))
        {
            None => {
                cloud.base_url = Some(OLLAMA_CLOUD_BASE_URL.to_string());
            }
            Some(current) if current.eq_ignore_ascii_case(LEGACY_OLLAMA_CLOUD_BASE_URL) => {
                cloud.base_url = Some(OLLAMA_CLOUD_BASE_URL.to_string());
            }
            _ => {}
        }
        if cloud.base_url != previous_base_url {
            changes.push(
                "normalised providers.ollama_cloud.base_url to https://ollama.com".to_string(),
            );
        }

        let original_api_key_env = cloud.api_key_env.clone();
        let needs_env_update = cloud
            .api_key_env
            .as_ref()
            .map(|value| value.trim().is_empty())
            .unwrap_or(true);
        if needs_env_update {
            cloud.api_key_env = Some(OLLAMA_API_KEY_ENV.to_string());
        }
        if let Some(ref value) = original_api_key_env
            && (value.eq_ignore_ascii_case(LEGACY_OLLAMA_CLOUD_API_KEY_ENV)
                || value.eq_ignore_ascii_case(LEGACY_OWLEN_OLLAMA_CLOUD_API_KEY_ENV))
        {
            cloud.api_key_env = Some(OLLAMA_API_KEY_ENV.to_string());
        }
        if cloud.api_key_env != original_api_key_env {
            changes
                .push("updated providers.ollama_cloud.api_key_env to 'OLLAMA_API_KEY'".to_string());
        }

        if ensure_string_extra_with_change(
            &mut cloud.extra,
            OLLAMA_CLOUD_ENDPOINT_KEY,
            OLLAMA_CLOUD_BASE_URL,
        ) {
            changes.push(
                "added providers.ollama_cloud.extra.cloud_endpoint (default https://ollama.com)"
                    .to_string(),
            );
        }
        if ensure_numeric_extra_with_change(
            &mut cloud.extra,
            "hourly_quota_tokens",
            DEFAULT_OLLAMA_CLOUD_HOURLY_QUOTA,
        ) {
            changes.push(format!(
                "added providers.ollama_cloud.hourly_quota_tokens (default {})",
                DEFAULT_OLLAMA_CLOUD_HOURLY_QUOTA
            ));
        }
        if ensure_numeric_extra_with_change(
            &mut cloud.extra,
            "weekly_quota_tokens",
            DEFAULT_OLLAMA_CLOUD_WEEKLY_QUOTA,
        ) {
            changes.push(format!(
                "added providers.ollama_cloud.weekly_quota_tokens (default {})",
                DEFAULT_OLLAMA_CLOUD_WEEKLY_QUOTA
            ));
        }
        if ensure_numeric_extra_with_change(
            &mut cloud.extra,
            "list_ttl_secs",
            DEFAULT_PROVIDER_LIST_TTL_SECS,
        ) {
            changes.push("added providers.ollama_cloud.list_ttl_secs (default 60)".to_string());
        }
        if ensure_numeric_extra_with_change(
            &mut cloud.extra,
            "default_context_window",
            u64::from(DEFAULT_PROVIDER_CONTEXT_WINDOW_TOKENS),
        ) {
            changes.push(format!(
                "added providers.ollama_cloud.default_context_window (default {})",
                DEFAULT_PROVIDER_CONTEXT_WINDOW_TOKENS
            ));
        }
    }

    let canonical_env = env::var(OLLAMA_API_KEY_ENV)
        .ok()
        .filter(|value| !value.trim().is_empty());
    let legacy_env = env::var(LEGACY_OLLAMA_CLOUD_API_KEY_ENV)
        .ok()
        .filter(|value| !value.trim().is_empty());
    let legacy_alt_env = env::var(LEGACY_OWLEN_OLLAMA_CLOUD_API_KEY_ENV)
        .ok()
        .filter(|value| !value.trim().is_empty());

    if canonical_env.is_some() {
        if legacy_env.is_some() {
            warnings.push(format!(
                "Both {OLLAMA_API_KEY_ENV} and {LEGACY_OLLAMA_CLOUD_API_KEY_ENV} are set; Owlen will prefer {OLLAMA_API_KEY_ENV}."
            ));
        }
        if legacy_alt_env.is_some() {
            warnings.push(format!(
                "Both {OLLAMA_API_KEY_ENV} and {LEGACY_OWLEN_OLLAMA_CLOUD_API_KEY_ENV} are set; Owlen will prefer {OLLAMA_API_KEY_ENV}."
            ));
        }
    } else {
        if legacy_env.is_some() {
            warnings.push(format!(
                "Legacy environment variable {LEGACY_OLLAMA_CLOUD_API_KEY_ENV} is set. Rename it to {OLLAMA_API_KEY_ENV} to match the latest configuration schema."
            ));
        }
        if legacy_alt_env.is_some() {
            warnings.push(format!(
                "Legacy environment variable {LEGACY_OWLEN_OLLAMA_CLOUD_API_KEY_ENV} is set. Rename it to {OLLAMA_API_KEY_ENV} to match the latest configuration schema."
            ));
        }
    }

    let mut ensure_default_enabled = true;

    if !config.providers.values().any(|cfg| cfg.enabled) {
        let entry = core_config::ensure_provider_config_mut(&mut config, "ollama_local");
        if !entry.enabled {
            entry.enabled = true;
            changes.push("no providers were enabled; enabled 'ollama_local'".to_string());
        }
        if config.general.default_provider != "ollama_local" {
            config.general.default_provider = "ollama_local".to_string();
            changes.push(
                "default provider reset to 'ollama_local' because no providers were enabled"
                    .to_string(),
            );
        }
        ensure_default_enabled = false;
    }

    if ensure_default_enabled {
        let default_id = config.general.default_provider.clone();
        if let Some(default_cfg) = config.providers.get(&default_id) && !default_cfg.enabled {
            if let Some(new_default) = config
                .providers
                .iter()
                .filter(|(id, cfg)| cfg.enabled && *id != &default_id)
                .map(|(id, _)| id.clone())
                .min()
            {
                config.general.default_provider = new_default.clone();
                changes.push(format!(
                    "default provider '{default_id}' was disabled; switched default to '{new_default}'"
                ));
            } else {
                let entry =
                    core_config::ensure_provider_config_mut(&mut config, "ollama_local");
                if !entry.enabled {
                    entry.enabled = true;
                    changes.push(
                        "enabled 'ollama_local' because default provider was disabled"
                            .to_string(),
                    );
                }
                if config.general.default_provider != "ollama_local" {
                    config.general.default_provider = "ollama_local".to_string();
                    changes.push(
                        "default provider reset to 'ollama_local' because previous default was disabled"
                            .to_string(),
                    );
                }
            }
        }
    }

    match config.mcp.mode {
        McpMode::Legacy => {
            config.mcp.mode = McpMode::LocalOnly;
            config.mcp.warn_on_legacy = true;
            changes.push("converted [mcp].mode = 'legacy' to 'local_only'".to_string());
        }
        McpMode::RemoteOnly if config.effective_mcp_servers().is_empty() => {
            config.mcp.mode = McpMode::RemotePreferred;
            config.mcp.allow_fallback = true;
            changes.push(
                "downgraded remote-only configuration to remote_preferred because no servers are defined"
                    .to_string(),
            );
        }
        McpMode::RemotePreferred
            if !config.mcp.allow_fallback && config.effective_mcp_servers().is_empty() =>
        {
            config.mcp.allow_fallback = true;
            changes.push(
                "enabled [mcp].allow_fallback because no remote servers are configured".to_string(),
            );
        }
        _ => {}
    }

    config.validate()?;
    config::save_config(&config)?;

    if changes.is_empty() {
        println!(
            "Configuration already up to date: {}",
            config_path.display()
        );
    } else {
        println!("Updated {}:", config_path.display());
        for change in changes {
            println!(" - {change}");
        }
    }

    if !warnings.is_empty() {
        println!("Warnings:");
        for warning in warnings {
            println!(" - {warning}");
        }
    }

    Ok(())
}

fn ensure_numeric_extra_with_change(
    extra: &mut std::collections::HashMap<String, JsonValue>,
    key: &str,
    default_value: u64,
) -> bool {
    match extra.get_mut(key) {
        Some(existing) => {
            if existing.as_u64().is_some() {
                false
            } else {
                *existing = JsonValue::Number(JsonNumber::from(default_value));
                true
            }
        }
        None => {
            extra.insert(
                key.to_string(),
                JsonValue::Number(JsonNumber::from(default_value)),
            );
            true
        }
    }
}

fn ensure_string_extra_with_change(
    extra: &mut std::collections::HashMap<String, JsonValue>,
    key: &str,
    default_value: &str,
) -> bool {
    match extra.get_mut(key) {
        Some(existing) => match existing.as_str() {
            Some(value) if !value.trim().is_empty() => false,
            _ => {
                *existing = JsonValue::String(default_value.to_string());
                true
            }
        },
        None => {
            extra.insert(
                key.to_string(),
                JsonValue::String(default_value.to_string()),
            );
            true
        }
    }
}

#[tokio::main(flavor = "multi_thread")]
async fn main() -> Result<()> {
    // Parse command-line arguments
    let Args {
        code,
        command,
        no_auto_compress,
    } = Args::parse();
    if let Some(command) = command {
        return run_command(command).await;
    }
    let initial_mode = if code { Mode::Code } else { Mode::Chat };
    bootstrap::launch(
        initial_mode,
        bootstrap::LaunchOptions {
            disable_auto_compress: no_auto_compress,
        },
    )
    .await
}
@@ -1,260 +0,0 @@
use std::collections::{HashMap, HashSet};

use anyhow::{Result, anyhow};
use clap::{Args, Subcommand, ValueEnum};
use owlen_core::config::{self as core_config, Config, McpConfigScope, McpServerConfig};
use owlen_tui::config as tui_config;

#[derive(Debug, Subcommand)]
pub enum McpCommand {
    /// Add or update an MCP server in the selected scope
    Add(AddArgs),
    /// List MCP servers across scopes
    List(ListArgs),
    /// Remove an MCP server from a scope
    Remove(RemoveArgs),
}

pub fn run_mcp_command(command: McpCommand) -> Result<()> {
    match command {
        McpCommand::Add(args) => handle_add(args),
        McpCommand::List(args) => handle_list(args),
        McpCommand::Remove(args) => handle_remove(args),
    }
}

#[derive(Debug, Clone, Copy, ValueEnum, Default)]
pub enum ScopeArg {
    User,
    #[default]
    Project,
    Local,
}

impl From<ScopeArg> for McpConfigScope {
    fn from(value: ScopeArg) -> Self {
        match value {
            ScopeArg::User => McpConfigScope::User,
            ScopeArg::Project => McpConfigScope::Project,
            ScopeArg::Local => McpConfigScope::Local,
        }
    }
}

#[derive(Debug, Args)]
pub struct AddArgs {
    /// Logical name used to reference the server
    pub name: String,
    /// Command or endpoint invoked for the server
    pub command: String,
    /// Transport mechanism (stdio, http, websocket)
    #[arg(long, default_value = "stdio")]
    pub transport: String,
    /// Configuration scope to write the server into
    #[arg(long, value_enum, default_value_t = ScopeArg::Project)]
    pub scope: ScopeArg,
    /// Environment variables (KEY=VALUE) passed to the server process
    #[arg(long = "env")]
    pub env: Vec<String>,
    /// Additional arguments appended when launching the server
    #[arg(trailing_var_arg = true, value_name = "ARG")]
    pub args: Vec<String>,
}

#[derive(Debug, Args, Default)]
pub struct ListArgs {
    /// Restrict output to a specific configuration scope
    #[arg(long, value_enum)]
    pub scope: Option<ScopeArg>,
    /// Display only the effective servers (after precedence resolution)
    #[arg(long)]
    pub effective_only: bool,
}

#[derive(Debug, Args)]
pub struct RemoveArgs {
    /// Name of the server to remove
    pub name: String,
    /// Optional explicit scope to remove from
    #[arg(long, value_enum)]
    pub scope: Option<ScopeArg>,
}

fn handle_add(args: AddArgs) -> Result<()> {
    let mut config = load_config()?;
    let scope: McpConfigScope = args.scope.into();
    let mut env_map = HashMap::new();
    for pair in &args.env {
        let (key, value) = pair
            .split_once('=')
            .ok_or_else(|| anyhow!("Environment pairs must use KEY=VALUE syntax: '{}'", pair))?;
        if key.trim().is_empty() {
            return Err(anyhow!("Environment variable name cannot be empty"));
        }
        env_map.insert(key.trim().to_string(), value.to_string());
    }

    let server = McpServerConfig {
        name: args.name.clone(),
        command: args.command.clone(),
        args: args.args.clone(),
        transport: args.transport.to_lowercase(),
        env: env_map,
        oauth: None,
        rpc_timeout_secs: None,
    };

    config.add_mcp_server(scope, server.clone(), None)?;
    if matches!(scope, McpConfigScope::User) {
        tui_config::save_config(&config)?;
    }

    if let Some(path) = core_config::mcp_scope_path(scope, None) {
        println!(
            "Registered MCP server '{}' in {} scope ({})",
            server.name,
            scope,
            path.display()
        );
    } else {
        println!(
            "Registered MCP server '{}' in {} scope.",
            server.name, scope
        );
    }

    Ok(())
}

fn handle_list(args: ListArgs) -> Result<()> {
    let mut config = load_config()?;
    config.refresh_mcp_servers(None)?;

    let scoped = config.scoped_mcp_servers();
    if scoped.is_empty() {
        println!("No MCP servers configured.");
        return Ok(());
    }

    let filter_scope = args.scope.map(|scope| scope.into());
    let effective = config.effective_mcp_servers();
    let mut active = HashSet::new();
    for server in effective {
        active.insert((
            server.name.clone(),
            server.command.clone(),
            server.transport.to_lowercase(),
        ));
    }

    println!(
        "{:<2} {:<8} {:<20} {:<10} Command",
        "", "Scope", "Name", "Transport"
    );
    for entry in scoped {
        if filter_scope
            .as_ref()
            .is_some_and(|target_scope| entry.scope != *target_scope)
        {
            continue;
        }

        let payload = format_command_line(&entry.config.command, &entry.config.args);
        let key = (
            entry.config.name.clone(),
            entry.config.command.clone(),
            entry.config.transport.to_lowercase(),
        );
        let marker = if active.contains(&key) { "*" } else { " " };

        if args.effective_only && marker != "*" {
            continue;
        }

        println!(
            "{} {:<8} {:<20} {:<10} {}",
            marker, entry.scope, entry.config.name, entry.config.transport, payload
        );
    }

    let scoped_resources = config.scoped_mcp_resources();
    if !scoped_resources.is_empty() {
        println!();
        println!("{:<2} {:<8} {:<30} Title", "", "Scope", "Resource");
        let effective_keys: HashSet<(String, String)> = config
            .effective_mcp_resources()
            .iter()
            .map(|res| (res.server.clone(), res.uri.clone()))
            .collect();

        for entry in scoped_resources {
            if filter_scope
                .as_ref()
                .is_some_and(|target_scope| entry.scope != *target_scope)
            {
                continue;
            }

            let key = (entry.config.server.clone(), entry.config.uri.clone());
            let marker = if effective_keys.contains(&key) {
                "*"
            } else {
                " "
            };
            if args.effective_only && marker != "*" {
                continue;
            }

            let reference = format!("@{}:{}", entry.config.server, entry.config.uri);
            let title = entry.config.title.as_deref().unwrap_or("—");

            println!("{} {:<8} {:<30} {}", marker, entry.scope, reference, title);
        }
    }

    Ok(())
}

fn handle_remove(args: RemoveArgs) -> Result<()> {
    let mut config = load_config()?;
    let scope_hint = args.scope.map(|scope| scope.into());
    let result = config.remove_mcp_server(scope_hint, &args.name, None)?;

    match result {
        Some(scope) => {
            if matches!(scope, McpConfigScope::User) {
                tui_config::save_config(&config)?;
            }

            if let Some(path) = core_config::mcp_scope_path(scope, None) {
                println!(
                    "Removed MCP server '{}' from {} scope ({})",
                    args.name,
                    scope,
                    path.display()
                );
            } else {
                println!("Removed MCP server '{}' from {} scope.", args.name, scope);
            }
        }
        None => {
            println!("No MCP server named '{}' was found.", args.name);
        }
    }

    Ok(())
}

fn load_config() -> Result<Config> {
    let mut config = tui_config::try_load_config().unwrap_or_default();
    config.refresh_mcp_servers(None)?;
    Ok(config)
}

fn format_command_line(command: &str, args: &[String]) -> String {
    if args.is_empty() {
        command.to_string()
    } else {
        format!("{} {}", command, args.join(" "))
    }
}
@@ -1,275 +0,0 @@
//! Integration tests for the ReAct agent loop functionality.
//!
//! These tests verify that the agent executor correctly:
//! - Parses ReAct formatted responses
//! - Executes tool calls
//! - Handles multi-step workflows
//! - Recovers from errors
//! - Respects iteration limits

use owlen_cli::agent::{AgentConfig, AgentExecutor, LlmResponse};
use owlen_core::mcp::remote_client::RemoteMcpClient;
use owlen_core::tools::WEB_SEARCH_TOOL_NAME;
use std::sync::Arc;

#[tokio::test]
async fn test_react_parsing_tool_call() {
    let executor = create_test_executor().await;

    // Test parsing a tool call with JSON arguments
    let text = "THOUGHT: I should search for information\nACTION: web_search\nACTION_INPUT: {\"query\": \"rust async programming\"}\n";

    let result = executor.parse_response(text);

    match result {
        Ok(LlmResponse::ToolCall {
            thought,
            tool_name,
            arguments,
        }) => {
            assert_eq!(thought, "I should search for information");
            assert_eq!(tool_name.as_str(), WEB_SEARCH_TOOL_NAME);
            assert_eq!(arguments["query"], "rust async programming");
        }
        other => panic!("Expected ToolCall, got: {:?}", other),
    }
}

#[tokio::test]
async fn test_react_parsing_final_answer() {
    let executor = create_test_executor().await;

    let text = "THOUGHT: I have enough information now\nFINAL_ANSWER: The answer is 42\n";

    let result = executor.parse_response(text);

    match result {
        Ok(LlmResponse::FinalAnswer { thought, answer }) => {
            assert_eq!(thought, "I have enough information now");
            assert_eq!(answer, "The answer is 42");
        }
        other => panic!("Expected FinalAnswer, got: {:?}", other),
    }
}

#[tokio::test]
async fn test_react_parsing_with_multiline_thought() {
    let executor = create_test_executor().await;

    let text = "THOUGHT: This is a complex\nmulti-line thought\nACTION: list_files\nACTION_INPUT: {\"path\": \".\"}\n";

    let result = executor.parse_response(text);

    // The parser folds thought lines together until it reaches an ACTION or
    // FINAL_ANSWER marker; this test documents current behavior
    match result {
        Ok(LlmResponse::ToolCall { thought, .. }) => {
            // At minimum the first line of the thought must survive parsing
            assert!(thought.contains("This is a complex"));
        }
        other => panic!("Expected ToolCall, got: {:?}", other),
    }
}

#[tokio::test]
#[ignore] // Requires MCP LLM server to be running
async fn test_agent_single_tool_scenario() {
    // This test requires a running MCP LLM server (which wraps Ollama)
    let provider = Arc::new(RemoteMcpClient::new().await.unwrap());
    let mcp_client = Arc::clone(&provider);

    let config = AgentConfig {
        max_iterations: 5,
        model: "llama3.2".to_string(),
        temperature: Some(0.7),
        max_tokens: None,
        ..AgentConfig::default()
    };

    let executor = AgentExecutor::new(provider, mcp_client, config);

    // Simple query that should complete in one tool call
    let result = executor
        .run("List files in the current directory".to_string())
        .await;

    match result {
        Ok(agent_result) => {
            assert!(
                !agent_result.answer.is_empty(),
                "Answer should not be empty"
            );
            println!("Agent answer: {}", agent_result.answer);
        }
        Err(e) => {
            // It's okay if this fails due to the LLM not following the format
            println!("Agent test skipped: {}", e);
        }
    }
}

#[tokio::test]
#[ignore] // Requires Ollama to be running
async fn test_agent_multi_step_workflow() {
    // Test a query that requires multiple tool calls
    let provider = Arc::new(RemoteMcpClient::new().await.unwrap());
    let mcp_client = Arc::clone(&provider);

    let config = AgentConfig {
        max_iterations: 10,
        model: "llama3.2".to_string(),
        temperature: Some(0.5), // Lower temperature for more consistent behavior
        max_tokens: None,
        ..AgentConfig::default()
    };

    let executor = AgentExecutor::new(provider, mcp_client, config);

    // Query requiring multiple steps: list -> read -> analyze
    let result = executor
        .run("Find all Rust files and tell me which one contains 'Agent'".to_string())
        .await;

    match result {
        Ok(agent_result) => {
            assert!(!agent_result.answer.is_empty());
            println!("Multi-step answer: {:?}", agent_result);
        }
        Err(e) => {
            println!("Multi-step test skipped: {}", e);
        }
    }
}

#[tokio::test]
#[ignore] // Requires Ollama
async fn test_agent_iteration_limit() {
    let provider = Arc::new(RemoteMcpClient::new().await.unwrap());
    let mcp_client = Arc::clone(&provider);

    let config = AgentConfig {
        max_iterations: 2, // Very low limit to test enforcement
        model: "llama3.2".to_string(),
        temperature: Some(0.7),
        max_tokens: None,
        ..AgentConfig::default()
    };

    let executor = AgentExecutor::new(provider, mcp_client, config);

    // Complex query that would require many iterations
    let result = executor
        .run("Perform an exhaustive analysis of all files".to_string())
        .await;

    // Should hit the iteration limit (or a parse error if the LLM doesn't follow the format)
    match result {
        Err(e) => {
            let error_str = format!("{}", e);
            // Accept either iteration limit error or parse error (LLM didn't follow ReAct format)
            assert!(
                error_str.contains("Maximum iterations")
                    || error_str.contains("2")
                    || error_str.contains("parse"),
                "Expected iteration limit or parse error, got: {}",
                error_str
            );
            println!("Test passed: agent stopped with error: {}", error_str);
        }
        Ok(_) => {
            // It's possible the LLM completed within 2 iterations
            println!("Agent completed within iteration limit");
        }
    }
}

#[tokio::test]
#[ignore] // Requires Ollama
async fn test_agent_tool_budget_enforcement() {
    let provider = Arc::new(RemoteMcpClient::new().await.unwrap());
    let mcp_client = Arc::clone(&provider);

    let config = AgentConfig {
        max_iterations: 3, // Very low iteration limit to enforce budget
        model: "llama3.2".to_string(),
        temperature: Some(0.7),
        max_tokens: None,
        ..AgentConfig::default()
    };

    let executor = AgentExecutor::new(provider, mcp_client, config);

    // Query that would require many tool calls
    let result = executor
        .run("Read every file in the project and summarize them all".to_string())
        .await;

    // Should hit the tool call budget (or a parse error if the LLM doesn't follow the format)
    match result {
        Err(e) => {
            let error_str = format!("{}", e);
            // Accept either budget error or parse error (LLM didn't follow ReAct format)
            assert!(
                error_str.contains("Maximum iterations")
                    || error_str.contains("budget")
                    || error_str.contains("parse"),
                "Expected budget or parse error, got: {}",
                error_str
            );
            println!("Test passed: agent stopped with error: {}", error_str);
        }
        Ok(_) => {
            println!("Agent completed within tool budget");
        }
    }
}

// Helper function to create a test executor
// For parsing tests, we don't need a real connection
async fn create_test_executor() -> AgentExecutor {
    // For parsing tests, we can accept the error from RemoteMcpClient::new()
    // since we're only testing parse_response which doesn't use the MCP client
    let provider = match RemoteMcpClient::new().await {
        Ok(client) => Arc::new(client),
        Err(_) => {
            // If the MCP server binary doesn't exist, parsing tests can still run
            // by using a dummy client that will never be called
            // This is a workaround for unit tests that only need parse_response
            panic!("MCP server binary not found - build the project first with: cargo build --all");
        }
    };

    let mcp_client = Arc::clone(&provider);

    let config = AgentConfig::default();
    AgentExecutor::new(provider, mcp_client, config)
}

#[test]
fn test_agent_config_defaults() {
    let config = AgentConfig::default();

    assert_eq!(config.max_iterations, 15);
    assert_eq!(config.model, "llama3.2:latest");
    assert_eq!(config.temperature, Some(0.7));
    assert_eq!(config.system_prompt, None);
    assert!(config.sub_agents.is_empty());
    // max_tool_calls field removed - agent now tracks iterations instead
}

#[test]
fn test_agent_config_custom() {
    let config = AgentConfig {
        max_iterations: 15,
        model: "custom-model".to_string(),
        temperature: Some(0.5),
        max_tokens: Some(2000),
        system_prompt: Some("Custom prompt".to_string()),
        sub_agents: Vec::new(),
    };

    assert_eq!(config.max_iterations, 15);
    assert_eq!(config.model, "custom-model");
    assert_eq!(config.temperature, Some(0.5));
    assert_eq!(config.max_tokens, Some(2000));
}
@@ -1,52 +0,0 @@
[package]
name = "owlen-core"
version.workspace = true
edition.workspace = true
authors.workspace = true
license.workspace = true
repository.workspace = true
homepage.workspace = true
description = "Core traits and types for OWLEN LLM client"

[dependencies]
owlen-ui-common = { path = "../owlen-ui-common" }
anyhow = { workspace = true }
log = { workspace = true }
regex = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
thiserror = { workspace = true }
tokio = { workspace = true }
unicode-segmentation = "1.11"
unicode-width = "0.2"
uuid = { workspace = true }
textwrap = { workspace = true }
futures = { workspace = true }
futures-util = { workspace = true }
async-trait = { workspace = true }
toml = { workspace = true }
shellexpand = { workspace = true }
dirs = { workspace = true }
tempfile = { workspace = true }
jsonschema = { workspace = true }
which = { workspace = true }
nix = { workspace = true }
aes-gcm = { workspace = true }
ring = { workspace = true }
keyring = { workspace = true }
chrono = { workspace = true }
urlencoding = { workspace = true }
sqlx = { workspace = true }
reqwest = { workspace = true, features = ["default"] }
path-clean = "1.0"
tokio-stream = { workspace = true }
tokio-tungstenite = "0.21"
tungstenite = "0.21"
ollama-rs = { version = "=0.3.2", features = ["stream", "headers"] }
once_cell = { workspace = true }
base64 = { workspace = true }

[dev-dependencies]
tokio-test = { workspace = true }
httpmock = "0.7"
wiremock = "0.6"
@@ -1,12 +0,0 @@
# Owlen Core

This crate provides the core abstractions and data structures for the Owlen ecosystem.

It defines the essential traits and types that enable communication with various LLM providers, manage sessions, and handle configuration.

## Key Components

- **`Provider` trait**: The fundamental abstraction for all LLM providers. Implement this trait to add support for a new provider.
- **`Session`**: Represents a single conversation, managing message history and context.
- **`Model`**: Defines the structure for LLM models, including their names and properties.
- **Configuration**: Handles loading and parsing of the application's configuration.
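
For orientation, a minimal sketch of what implementing a provider can look like is shown below. The trait here is deliberately simplified; the crate's real `Provider` trait is richer, and every type and name in this snippet is an illustrative assumption rather than the actual API.

```rust
use async_trait::async_trait;

// Simplified stand-ins for the crate's chat types; illustrative only.
pub struct ChatRequest {
    pub model: String,
    pub messages: Vec<String>,
}

pub struct ChatResponse {
    pub content: String,
}

// A cut-down version of the provider abstraction described above.
#[async_trait]
pub trait Provider: Send + Sync {
    async fn send_prompt(&self, request: ChatRequest) -> anyhow::Result<ChatResponse>;
}

// A toy provider that echoes the last message back; a real provider would
// make an HTTP call to its backend here instead.
pub struct EchoProvider;

#[async_trait]
impl Provider for EchoProvider {
    async fn send_prompt(&self, request: ChatRequest) -> anyhow::Result<ChatResponse> {
        let content = request.messages.last().cloned().unwrap_or_default();
        Ok(ChatResponse { content })
    }
}
```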
@@ -1,12 +0,0 @@
CREATE TABLE IF NOT EXISTS conversations (
    id TEXT PRIMARY KEY,
    name TEXT,
    description TEXT,
    model TEXT NOT NULL,
    message_count INTEGER NOT NULL,
    created_at INTEGER NOT NULL,
    updated_at INTEGER NOT NULL,
    data TEXT NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_conversations_updated_at ON conversations(updated_at DESC);
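
// Illustrative aside (not part of the migration above): the descending index
// exists to serve "most recently updated first" listings. A minimal sqlx
// sketch of that query, assuming a SQLite pool; the function name and error
// handling are assumptions, not code from this crate.
use sqlx::{Row, SqlitePool};

async fn recent_conversations(pool: &SqlitePool, limit: i64) -> anyhow::Result<Vec<(String, i64)>> {
    let rows = sqlx::query("SELECT id, updated_at FROM conversations ORDER BY updated_at DESC LIMIT ?")
        .bind(limit)
        .fetch_all(pool)
        .await?;

    // Pull out (id, updated_at) pairs; try_get fails if a column is missing.
    rows.into_iter()
        .map(|row| Ok((row.try_get("id")?, row.try_get("updated_at")?)))
        .collect()
}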
@@ -1,7 +0,0 @@
CREATE TABLE IF NOT EXISTS secure_items (
    key TEXT PRIMARY KEY,
    nonce BLOB NOT NULL,
    ciphertext BLOB NOT NULL,
    created_at INTEGER NOT NULL,
    updated_at INTEGER NOT NULL
);
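
// Illustrative aside (not part of the migration above): the nonce/ciphertext
// pair matches AEAD storage, where each record gets a fresh nonce stored
// beside its ciphertext. A hedged sketch using the aes-gcm crate listed in
// this crate's dependencies; key sourcing and all names are assumptions.
use aes_gcm::aead::{Aead, AeadCore, KeyInit, OsRng};
use aes_gcm::{Aes256Gcm, Key};

fn seal(key_bytes: &[u8; 32], plaintext: &[u8]) -> anyhow::Result<(Vec<u8>, Vec<u8>)> {
    let key = Key::<Aes256Gcm>::from_slice(key_bytes);
    let cipher = Aes256Gcm::new(key);
    // Never reuse a nonce under the same key; generate a fresh one per item.
    let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
    let ciphertext = cipher
        .encrypt(&nonce, plaintext)
        .map_err(|err| anyhow::anyhow!("encryption failed: {err}"))?;
    // Store these as the `nonce` and `ciphertext` BLOB columns.
    Ok((nonce.to_vec(), ciphertext))
}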
@@ -1,478 +0,0 @@
//! Agentic execution loop with ReAct pattern support.
//!
//! This module provides the core agent orchestration logic that allows an LLM
//! to reason about tasks, execute tools, and observe results in an iterative loop.
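
// An illustrative transcript of one pass through the loop this module
// implements. The tool name and arguments are hypothetical, but the markers
// match what build_system_prompt requests and parse_response consumes:
//
//   THOUGHT: I need to see the project layout first.
//   ACTION: list_files
//   ACTION_INPUT: {"path": "."}
//   OBSERVATION: ["src/main.rs", "Cargo.toml"]
//   THOUGHT: I have enough information now.
//   FINAL_ANSWER: The project is a single-binary Cargo crate.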
use crate::Provider;
use crate::mcp::{McpClient, McpToolCall, McpToolDescriptor, McpToolResponse};
use crate::types::{ChatParameters, ChatRequest, Message, MessageAttachment};
use crate::{Error, Result, SubAgentSpec};
use serde::{Deserialize, Serialize};
use std::sync::Arc;

/// Maximum number of agent iterations before stopping
const DEFAULT_MAX_ITERATIONS: usize = 15;

/// Parsed response from the LLM in ReAct format
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum LlmResponse {
    /// LLM wants to execute a tool
    ToolCall {
        thought: String,
        tool_name: String,
        arguments: serde_json::Value,
    },
    /// LLM has reached a final answer
    FinalAnswer { thought: String, answer: String },
    /// LLM is just reasoning without taking action
    Reasoning { thought: String },
}

fn assemble_prompt_with_tools_and_subagents(
    base_prompt: &str,
    tools: &[McpToolDescriptor],
    sub_agents: &[SubAgentSpec],
) -> String {
    let mut prompt = base_prompt.trim().to_string();
    prompt.push_str("\n\nYou have access to the following tools:\n");
    for tool in tools {
        prompt.push_str(&format!("- {}: {}\n", tool.name, tool.description));
    }
    append_subagent_guidance(&mut prompt, sub_agents);
    prompt
}

fn append_subagent_guidance(prompt: &mut String, sub_agents: &[SubAgentSpec]) {
    if sub_agents.is_empty() {
        return;
    }

    prompt.push_str("\nYou may delegate focused tasks to the following specialised sub-agents:\n");
    for sub in sub_agents {
        prompt.push_str(&format!(
            "- {}: {}\n{}\n",
            sub.name.as_deref().unwrap_or(sub.id.as_str()),
            sub.description
                .as_deref()
                .unwrap_or("No description provided."),
            sub.prompt.trim()
        ));
    }
}

/// Parse error when LLM response doesn't match expected format
#[derive(Debug, thiserror::Error)]
pub enum ParseError {
    #[error("No recognizable pattern found in response")]
    NoPattern,
    #[error("Missing required field: {0}")]
    MissingField(String),
    #[error("Invalid JSON in ACTION_INPUT: {0}")]
    InvalidJson(String),
}

/// Result of an agent execution
#[derive(Debug, Clone)]
pub struct AgentResult {
    /// Final answer from the agent
    pub answer: String,
    /// Number of iterations taken
    pub iterations: usize,
    /// All messages exchanged during execution
    pub messages: Vec<Message>,
    /// Whether the agent completed successfully
    pub success: bool,
}

/// Configuration for agent execution
#[derive(Debug, Clone)]
pub struct AgentConfig {
    /// Maximum number of iterations
    pub max_iterations: usize,
    /// Model to use for reasoning
    pub model: String,
    /// Temperature for LLM sampling
    pub temperature: Option<f32>,
    /// Max tokens per LLM call
    pub max_tokens: Option<u32>,
    /// Optional override for the system prompt presented to the LLM.
    pub system_prompt: Option<String>,
    /// Optional sub-agent prompts exposed to the executor.
    pub sub_agents: Vec<SubAgentSpec>,
}

impl Default for AgentConfig {
    fn default() -> Self {
        Self {
            max_iterations: DEFAULT_MAX_ITERATIONS,
            model: "llama3.2:latest".to_string(),
            temperature: Some(0.7),
            max_tokens: Some(4096),
            system_prompt: None,
            sub_agents: Vec::new(),
        }
    }
}

/// Agent executor that orchestrates the ReAct loop
pub struct AgentExecutor {
    /// LLM provider for reasoning
    llm_client: Arc<dyn Provider>,
    /// MCP client for tool execution
    tool_client: Arc<dyn McpClient>,
    /// Agent configuration
    config: AgentConfig,
}

impl AgentExecutor {
    /// Create a new agent executor
    pub fn new(
        llm_client: Arc<dyn Provider>,
        tool_client: Arc<dyn McpClient>,
        config: AgentConfig,
    ) -> Self {
        Self {
            llm_client,
            tool_client,
            config,
        }
    }

    /// Run the agent loop with the given query
    pub async fn run(&self, query: String) -> Result<AgentResult> {
        self.run_with_attachments(query, Vec::new()).await
    }

    /// Run the agent loop with an initial multimodal payload.
    pub async fn run_with_attachments(
        &self,
        query: String,
        attachments: Vec<MessageAttachment>,
    ) -> Result<AgentResult> {
        let mut messages = vec![Message::user(query).with_attachments(attachments)];
        let tools = self.discover_tools().await?;

        for iteration in 0..self.config.max_iterations {
            let prompt = self.build_react_prompt(&messages, &tools);
            let response = self.generate_llm_response(prompt).await?;

            match self.parse_response(&response)? {
                LlmResponse::ToolCall {
                    thought,
                    tool_name,
                    arguments,
                } => {
                    // Add assistant's reasoning
                    messages.push(Message::assistant(format!(
                        "THOUGHT: {}\nACTION: {}\nACTION_INPUT: {}",
                        thought,
                        tool_name,
                        serde_json::to_string_pretty(&arguments).unwrap_or_default()
                    )));

                    // Execute the tool
                    let result = self.execute_tool(&tool_name, arguments).await?;

                    // Add observation
                    messages.push(Message::tool(
                        tool_name.clone(),
                        format!(
                            "OBSERVATION: {}",
                            serde_json::to_string_pretty(&result.output).unwrap_or_default()
                        ),
                    ));
                }
                LlmResponse::FinalAnswer { thought, answer } => {
                    messages.push(Message::assistant(format!(
                        "THOUGHT: {}\nFINAL_ANSWER: {}",
                        thought, answer
                    )));
                    return Ok(AgentResult {
                        answer,
                        iterations: iteration + 1,
                        messages,
                        success: true,
                    });
                }
                LlmResponse::Reasoning { thought } => {
                    messages.push(Message::assistant(format!("THOUGHT: {}", thought)));
                }
            }
        }

        // Max iterations reached
        Ok(AgentResult {
            answer: "Maximum iterations reached without finding a final answer".to_string(),
            iterations: self.config.max_iterations,
            messages,
            success: false,
        })
    }

    /// Discover available tools from the MCP client
    async fn discover_tools(&self) -> Result<Vec<McpToolDescriptor>> {
        self.tool_client.list_tools().await
    }

    /// Build a ReAct-formatted prompt with available tools
    fn build_react_prompt(
        &self,
        messages: &[Message],
        tools: &[McpToolDescriptor],
    ) -> Vec<Message> {
        let mut prompt_messages = Vec::new();

        // System prompt with ReAct instructions
        let system_prompt = self.build_system_prompt(tools);
        prompt_messages.push(Message::system(system_prompt));

        // Add conversation history
        prompt_messages.extend_from_slice(messages);

        prompt_messages
    }

    /// Build the system prompt with ReAct format and tool descriptions
    fn build_system_prompt(&self, tools: &[McpToolDescriptor]) -> String {
        if let Some(custom) = &self.config.system_prompt {
            return assemble_prompt_with_tools_and_subagents(
                custom,
                tools,
                &self.config.sub_agents,
            );
        }

        let mut prompt = String::from(
            "You are an AI assistant that uses the ReAct (Reasoning and Acting) pattern to solve tasks.\n\n\
             You have access to the following tools:\n\n",
        );

        for tool in tools {
            prompt.push_str(&format!("- {}: {}\n", tool.name, tool.description));
        }

        prompt.push_str(
            "\nUse the following format:\n\n\
             THOUGHT: Your reasoning about what to do next\n\
             ACTION: tool_name\n\
             ACTION_INPUT: {\"param\": \"value\"}\n\n\
             You will receive:\n\
             OBSERVATION: The result of the tool execution\n\n\
             Continue this process until you have enough information, then provide:\n\
             THOUGHT: Final reasoning\n\
             FINAL_ANSWER: Your comprehensive answer\n\n\
             Important:\n\
             - Always start with THOUGHT to explain your reasoning\n\
             - ACTION must be one of the available tools\n\
             - ACTION_INPUT must be valid JSON\n\
             - Use FINAL_ANSWER only when you have sufficient information\n",
        );

        append_subagent_guidance(&mut prompt, &self.config.sub_agents);

        prompt
    }

    /// Generate an LLM response
    async fn generate_llm_response(&self, messages: Vec<Message>) -> Result<String> {
        let request = ChatRequest {
            model: self.config.model.clone(),
            messages,
            parameters: ChatParameters {
                temperature: self.config.temperature,
                max_tokens: self.config.max_tokens,
                stream: false,
                ..Default::default()
            },
            tools: None,
        };

        let response = self.llm_client.send_prompt(request).await?;
        Ok(response.message.content)
    }
    /// Parse LLM response into structured format
    pub fn parse_response(&self, text: &str) -> Result<LlmResponse> {
        let lines: Vec<&str> = text.lines().collect();
        let mut thought = String::new();
        let mut action = String::new();
        let mut action_input = String::new();
        let mut final_answer = String::new();

        let mut i = 0;
        while i < lines.len() {
            let line = lines[i].trim();

            if line.starts_with("THOUGHT:") {
                thought = line
                    .strip_prefix("THOUGHT:")
                    .unwrap_or("")
                    .trim()
                    .to_string();
                // Collect multi-line thoughts
                i += 1;
                while i < lines.len()
                    && !lines[i].trim().starts_with("ACTION")
                    && !lines[i].trim().starts_with("FINAL_ANSWER")
                {
                    if !lines[i].trim().is_empty() {
                        thought.push(' ');
                        thought.push_str(lines[i].trim());
                    }
                    i += 1;
                }
                continue;
            }

            if line.starts_with("ACTION:") {
                action = line
                    .strip_prefix("ACTION:")
                    .unwrap_or("")
                    .trim()
                    .to_string();
                i += 1;
                continue;
            }

            if line.starts_with("ACTION_INPUT:") {
                action_input = line
                    .strip_prefix("ACTION_INPUT:")
                    .unwrap_or("")
                    .trim()
                    .to_string();
                // Collect multi-line JSON
                i += 1;
                while i < lines.len()
                    && !lines[i].trim().starts_with("THOUGHT")
                    && !lines[i].trim().starts_with("ACTION")
                {
                    action_input.push(' ');
                    action_input.push_str(lines[i].trim());
                    i += 1;
                }
                continue;
            }

            if line.starts_with("FINAL_ANSWER:") {
                final_answer = line
                    .strip_prefix("FINAL_ANSWER:")
                    .unwrap_or("")
                    .trim()
                    .to_string();
                // Collect multi-line answer
                i += 1;
                while i < lines.len() {
                    if !lines[i].trim().is_empty() {
                        final_answer.push(' ');
                        final_answer.push_str(lines[i].trim());
                    }
                    i += 1;
                }
                break;
            }

            i += 1;
        }

        // Determine response type
        if !final_answer.is_empty() {
            return Ok(LlmResponse::FinalAnswer {
                thought,
                answer: final_answer,
            });
        }

        if !action.is_empty() {
            let arguments = if action_input.is_empty() {
                serde_json::json!({})
            } else {
                serde_json::from_str(&action_input)
                    .map_err(|e| Error::Agent(ParseError::InvalidJson(e.to_string()).to_string()))?
            };

            return Ok(LlmResponse::ToolCall {
                thought,
                tool_name: action,
                arguments,
            });
        }

        if !thought.is_empty() {
            return Ok(LlmResponse::Reasoning { thought });
        }

        Err(Error::Agent(ParseError::NoPattern.to_string()))
    }

    /// Execute a tool call
    async fn execute_tool(
        &self,
        tool_name: &str,
        arguments: serde_json::Value,
    ) -> Result<McpToolResponse> {
        let call = McpToolCall {
            name: tool_name.to_string(),
            arguments,
        };
        self.tool_client.call_tool(call).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::test_utils::MockProvider;
    use crate::mcp::test_utils::MockMcpClient;
    use crate::tools::WEB_SEARCH_TOOL_NAME;

    #[test]
    fn test_parse_tool_call() {
        let executor = AgentExecutor {
            llm_client: Arc::new(MockProvider::default()),
            tool_client: Arc::new(MockMcpClient),
            config: AgentConfig::default(),
        };

        let text = r#"
THOUGHT: I need to search for information about Rust
ACTION: web_search
ACTION_INPUT: {"query": "Rust programming language"}
"#;

        let result = executor.parse_response(text).unwrap();
        match result {
            LlmResponse::ToolCall {
                thought,
                tool_name,
                arguments,
            } => {
                assert!(thought.contains("search for information"));
                assert!(matches!(tool_name.as_str(), WEB_SEARCH_TOOL_NAME));
                assert_eq!(arguments["query"], "Rust programming language");
            }
            _ => panic!("Expected ToolCall"),
        }
    }

    #[test]
    fn test_parse_final_answer() {
        let executor = AgentExecutor {
            llm_client: Arc::new(MockProvider::default()),
            tool_client: Arc::new(MockMcpClient),
            config: AgentConfig::default(),
        };

        let text = r#"
THOUGHT: I now have enough information to answer
FINAL_ANSWER: Rust is a systems programming language focused on safety and performance.
"#;

        let result = executor.parse_response(text).unwrap();
        match result {
            LlmResponse::FinalAnswer { thought, answer } => {
                assert!(thought.contains("enough information"));
                assert!(answer.contains("Rust is a systems programming language"));
            }
            _ => panic!("Expected FinalAnswer"),
        }
    }
}
@@ -1,462 +0,0 @@
use crate::{Error, Result};
use serde::Deserialize;
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};

/// Maximum allowed size (bytes) for an agent prompt file.
const MAX_PROMPT_SIZE_BYTES: usize = 128 * 1024;

/// Definition of a sub-agent that can be referenced by the primary agent prompt.
#[derive(Debug, Clone)]
pub struct SubAgentSpec {
    pub id: String,
    pub name: Option<String>,
    pub description: Option<String>,
    pub prompt: String,
}

/// Fully resolved agent profile loaded from configuration files.
#[derive(Debug, Clone)]
pub struct AgentProfile {
    pub id: String,
    pub name: Option<String>,
    pub description: Option<String>,
    pub system_prompt: String,
    pub model: Option<String>,
    pub temperature: Option<f32>,
    pub max_iterations: Option<usize>,
    pub max_tokens: Option<u32>,
    pub tags: Vec<String>,
    pub sub_agents: Vec<SubAgentSpec>,
    pub source_path: PathBuf,
}

impl AgentProfile {
    pub fn display_name(&self) -> &str {
        self.name.as_deref().unwrap_or(self.id.as_str())
    }
}

/// Registry responsible for discovering and loading user-defined agent profiles.
#[derive(Debug, Clone, Default)]
pub struct AgentRegistry {
    profiles: Vec<AgentProfile>,
    index: HashMap<String, usize>,
    search_paths: Vec<PathBuf>,
}

impl AgentRegistry {
    /// Build a registry by discovering configuration in standard locations.
    pub fn discover(project_hint: Option<&Path>) -> Result<Self> {
        let mut search_paths = Vec::new();

        if let Some(config_dir) = dirs::config_dir() {
            search_paths.push(config_dir.join("owlen").join("agents"));
        }

        search_paths.extend(discover_project_agent_paths(project_hint));

        if let Ok(env) = std::env::var("OWLEN_AGENTS_PATH") {
            // Split on the platform's path-list separator (`:` on Unix,
            // `;` on Windows); splitting on `MAIN_SEPARATOR` would break
            // on the `/` inside absolute paths.
            for path in std::env::split_paths(&env) {
                if !path.as_os_str().is_empty() {
                    search_paths.push(path);
                }
            }
        }

        Self::load_from_paths(search_paths)
    }
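
    // A minimal usage sketch (hypothetical caller code, not part of this
    // module): discover profiles from the standard locations, then look one
    // up by id.
    //
    //     let registry = AgentRegistry::discover(None)?;
    //     if let Some(profile) = registry.get("support") {
    //         println!("using agent: {}", profile.display_name());
    //     }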

    /// Build the registry from explicit paths.
    pub fn load_from_paths(paths: Vec<PathBuf>) -> Result<Self> {
        let mut registry = Self {
            profiles: Vec::new(),
            index: HashMap::new(),
            search_paths: paths.clone(),
        };

        for path in paths {
            registry.load_directory(&path)?;
        }

        Ok(registry)
    }

    /// Return the list of discovered agent profiles.
    pub fn profiles(&self) -> &[AgentProfile] {
        &self.profiles
    }

    /// Return a profile by identifier.
    pub fn get(&self, id: &str) -> Option<&AgentProfile> {
        self.index.get(id).and_then(|idx| self.profiles.get(*idx))
    }

    /// Reload all search paths, replacing existing profiles.
    pub fn reload(&mut self) -> Result<()> {
        let paths = self.search_paths.clone();
        self.profiles.clear();
        self.index.clear();

        for path in paths {
            self.load_directory(&path)?;
        }

        Ok(())
    }

    fn load_directory(&mut self, dir: &Path) -> Result<()> {
        if !dir.exists() {
            return Ok(());
        }

        let mut files = Vec::new();
        collect_agent_files(dir, &mut files)?;
        files.sort();

        for file in files {
            match load_agent_file(&file) {
                Ok(mut profiles) => {
                    for profile in profiles.drain(..) {
                        let id = profile.id.clone();
                        if let Some(existing) = self.index.get(&id).copied() {
                            // Later search paths override earlier ones.
                            self.profiles[existing] = profile;
                        } else {
                            let idx = self.profiles.len();
                            self.profiles.push(profile);
                            self.index.insert(id, idx);
                        }
                    }
                }
                Err(err) => {
                    return Err(Error::Config(format!(
                        "Failed to load agent definition {}: {err}",
                        file.display()
                    )));
                }
            }
        }

        Ok(())
    }
}

fn collect_agent_files(dir: &Path, files: &mut Vec<PathBuf>) -> Result<()> {
    if !dir.exists() {
        return Ok(());
    }

    for entry in fs::read_dir(dir).map_err(Error::Io)? {
        let entry = entry.map_err(Error::Io)?;
        let path = entry.path();
        if path.is_dir() {
            collect_agent_files(&path, files)?;
        } else if path
            .extension()
            .and_then(|ext| ext.to_str())
            .map(|ext| ext.eq_ignore_ascii_case("toml"))
            .unwrap_or(false)
        {
            files.push(path);
        }
    }

    Ok(())
}

fn discover_project_agent_paths(project_hint: Option<&Path>) -> Vec<PathBuf> {
    let mut results = Vec::new();

    let mut current = project_hint
        .map(PathBuf::from)
        .or_else(|| std::env::current_dir().ok());

    while let Some(path) = current {
        let candidate = path.join(".owlen").join("agents");
        if candidate.exists() {
            results.push(candidate);
        }

        current = path.parent().map(PathBuf::from);
    }

    results
}

fn load_agent_file(path: &Path) -> Result<Vec<AgentProfile>> {
    let raw = fs::read_to_string(path).map_err(Error::Io)?;
    if raw.trim().is_empty() {
        return Ok(Vec::new());
    }

    let document: AgentDocument = toml::from_str(&raw)
        .map_err(|err| Error::Config(format!("Unable to parse {}: {err}", path.display())))?;

    let mut profiles = Vec::new();

    if document.agents.is_empty() {
        let single: SingleAgentFile = toml::from_str(&raw).map_err(|err| {
            Error::Config(format!(
                "Agent definition {} must contain either [[agents]] tables or top-level id/prompt fields: {err}",
                path.display()
            ))
        })?;
        profiles.push(resolve_agent_entry(path, &single.entry)?);
        return Ok(profiles);
    }

    for entry in document.agents {
        profiles.push(resolve_agent_entry(path, &entry)?);
    }

    Ok(profiles)
}

fn resolve_agent_entry(path: &Path, entry: &AgentEntry) -> Result<AgentProfile> {
    let base_dir = path
        .parent()
        .map(PathBuf::from)
        .unwrap_or_else(|| PathBuf::from("."));

    let system_prompt = entry
        .prompt
        .as_ref()
        .ok_or_else(|| {
            Error::Config(format!(
                "Agent '{}' in {} is missing a `prompt` value",
                entry.id,
                path.display()
            ))
        })?
        .resolve(&base_dir)?;

    let mut sub_agents = Vec::new();
    for (id, sub) in &entry.sub_agents {
        let prompt = sub.prompt.resolve(&base_dir)?;
        sub_agents.push(SubAgentSpec {
            id: id.clone(),
            name: sub.name.clone(),
            description: sub.description.clone(),
            prompt,
        });
    }

    Ok(AgentProfile {
        id: entry.id.clone(),
        name: entry.name.clone(),
        description: entry.description.clone(),
        system_prompt,
        model: entry.parameters.as_ref().and_then(|p| p.model.clone()),
        temperature: entry.parameters.as_ref().and_then(|p| p.temperature),
        max_iterations: entry.parameters.as_ref().and_then(|p| p.max_iterations),
        max_tokens: entry.parameters.as_ref().and_then(|p| p.max_tokens),
        tags: entry.tags.clone().unwrap_or_default(),
        sub_agents,
        source_path: path.to_path_buf(),
    })
}

#[derive(Debug, Deserialize)]
struct AgentDocument {
    // Renamed so the documented `version = "1"` key actually maps onto this
    // (otherwise serde would look for a literal `_version` key).
    #[serde(default = "default_schema_version", rename = "version")]
    _version: String,
    #[serde(default)]
    agents: Vec<AgentEntry>,
}

#[derive(Debug, Deserialize)]
struct SingleAgentFile {
    #[serde(default = "default_schema_version", rename = "version")]
    _version: String,
    #[serde(flatten)]
    entry: AgentEntry,
}

fn default_schema_version() -> String {
    "1".to_string()
}

#[derive(Debug, Deserialize)]
struct AgentEntry {
    id: String,
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    description: Option<String>,
    #[serde(default)]
    tags: Option<Vec<String>>,
    #[serde(default)]
    prompt: Option<PromptSpec>,
    #[serde(default)]
    parameters: Option<AgentParameters>,
    #[serde(default)]
    sub_agents: HashMap<String, SubAgentEntry>,
}

#[derive(Debug, Deserialize)]
struct AgentParameters {
    #[serde(default)]
    model: Option<String>,
    #[serde(default)]
    temperature: Option<f32>,
    #[serde(default)]
    max_iterations: Option<usize>,
    #[serde(default)]
    max_tokens: Option<u32>,
}

#[derive(Debug, Deserialize)]
struct SubAgentEntry {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    description: Option<String>,
    prompt: PromptSpec,
}

#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum PromptSpec {
    Inline(String),
    Source { file: String },
}
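
// `PromptSpec` is untagged, so both TOML shapes below deserialize into it; a
// short sketch of the two accepted forms (the file path is illustrative):
//
//     prompt = "You are a helpful assistant."        # PromptSpec::Inline
//     prompt = { file = "prompts/researcher.md" }    # PromptSpec::Source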

impl PromptSpec {
    fn resolve(&self, base_dir: &Path) -> Result<String> {
        match self {
            PromptSpec::Inline(value) => Ok(value.trim().to_string()),
            PromptSpec::Source { file } => {
                let path = if Path::new(file).is_absolute() {
                    PathBuf::from(file)
                } else {
                    base_dir.join(file)
                };

                let data = fs::read(&path).map_err(Error::Io)?;
                if data.len() > MAX_PROMPT_SIZE_BYTES {
                    return Err(Error::Config(format!(
                        "Prompt file {} exceeds the maximum supported size ({MAX_PROMPT_SIZE_BYTES} bytes)",
                        path.display()
                    )));
                }

                let text = String::from_utf8(data).map_err(|_| {
                    Error::Config(format!("Prompt file {} is not valid UTF-8", path.display()))
                })?;

                Ok(text.trim().to_string())
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::tempdir;

    #[test]
    fn load_simple_agent() {
        let dir = tempdir().expect("temp dir");
        let agent_dir = dir.path().join("agents");
        fs::create_dir_all(&agent_dir).unwrap();

        let mut file = fs::File::create(agent_dir.join("support.toml")).unwrap();
        writeln!(
            file,
            r#"
version = "1"

[[agents]]
id = "support"
name = "Support Specialist"
description = "Handles user support tickets."
prompt = "You are a helpful support assistant."

[agents.parameters]
model = "gpt-4"
max_iterations = 8
temperature = 0.2

[agents.sub_agents.first_line]
name = "First-line support"
description = "Handles simple issues"
prompt = "Escalate complex issues."
"#
        )
        .unwrap();

        let registry = AgentRegistry::load_from_paths(vec![agent_dir]).unwrap();
        assert_eq!(registry.profiles.len(), 1);

        let profile = registry.get("support").unwrap();
        assert_eq!(profile.display_name(), "Support Specialist");
        assert_eq!(
            profile.system_prompt,
            "You are a helpful support assistant."
        );
        assert_eq!(profile.model.as_deref(), Some("gpt-4"));
        assert_eq!(profile.max_iterations, Some(8));
        assert_eq!(profile.sub_agents.len(), 1);
        assert_eq!(profile.sub_agents[0].id, "first_line");
    }

    #[test]
    fn prompt_from_file_resolves_relative_path() {
        let dir = tempdir().expect("temp dir");
        let agent_dir = dir.path().join(".owlen").join("agents");
        let prompt_dir = agent_dir.join("prompts");
        fs::create_dir_all(&prompt_dir).unwrap();

        fs::write(
            prompt_dir.join("researcher.md"),
            "Research the latest documentation updates.",
        )
        .unwrap();

        fs::write(
            agent_dir.join("doc.toml"),
            r#"
version = "1"

[[agents]]
id = "docs"
prompt = { file = "prompts/researcher.md" }
"#,
        )
        .unwrap();

        let registry = AgentRegistry::load_from_paths(vec![agent_dir]).unwrap();
        let profile = registry.get("docs").unwrap();
        assert_eq!(
            profile.system_prompt,
            "Research the latest documentation updates."
        );
    }

    #[test]
    fn load_agent_from_flat_document() {
        let dir = tempdir().expect("temp dir");
        let agent_dir = dir.path().join("agents");
        fs::create_dir_all(&agent_dir).unwrap();

        fs::write(
            agent_dir.join("flat.toml"),
            r#"
version = "1"
id = "flat"
name = "Flat Agent"
prompt = "Operate using flat configuration."
"#,
        )
        .unwrap();

        let registry = AgentRegistry::load_from_paths(vec![agent_dir]).unwrap();
        let profile = registry.get("flat").expect("profile present");
        assert_eq!(profile.display_name(), "Flat Agent");
        assert_eq!(profile.system_prompt, "Operate using flat configuration.");
    }
}
@@ -1,9 +0,0 @@
//! High-level automation APIs for repository workflows (commit templating, PR review, etc.).

pub mod repo;

pub use repo::{
    CommitTemplate, CommitTemplateSection, DiffCaptureMode, DiffStatistics, FileChange,
    PullRequestContext, PullRequestReview, RepoAutomation, ReviewChecklistItem, ReviewFinding,
    ReviewSeverity, WorkflowStep,
};
@@ -1,943 +0,0 @@
use crate::{Error, Result};
use serde::{Deserialize, Serialize};
use std::fmt;
use std::path::{Path, PathBuf};
use std::process::Command;
use std::str;

/// Controls which diff snapshot should be inspected.
#[derive(Debug, Clone, Copy)]
pub enum DiffCaptureMode<'a> {
    /// Inspect staged changes (`git diff --cached`).
    Staged,
    /// Inspect unstaged working-tree changes (`git diff`).
    WorkingTree,
    /// Inspect the diff between two refs.
    Range { base: &'a str, head: &'a str },
}
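
// A minimal usage sketch (hypothetical caller code): build the automation
// entry-point from any path inside a repository, then render a commit
// template for the staged changes.
//
//     let repo = RepoAutomation::from_path(".")?;
//     let template = repo.generate_commit_template(DiffCaptureMode::Staged)?;
//     println!("{}", template.render_markdown());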

/// High-level automation entry-point for repository-centric workflows.
pub struct RepoAutomation {
    repo_root: PathBuf,
}

impl RepoAutomation {
    /// Discover the git repository root starting from the provided path.
    pub fn from_path(path: impl AsRef<Path>) -> Result<Self> {
        let root = discover_repo_root(path.as_ref())?;
        Ok(Self { repo_root: root })
    }

    /// Return the repository root on disk.
    pub fn repo_root(&self) -> &Path {
        &self.repo_root
    }

    /// Generate a conventional commit template from the selected diff snapshot.
    pub fn generate_commit_template(&self, mode: DiffCaptureMode<'_>) -> Result<CommitTemplate> {
        let diff = capture_diff(&self.repo_root, mode)?;
        if diff.trim().is_empty() {
            return Err(Error::InvalidInput(
                "No changes detected for the selected diff snapshot.".to_string(),
            ));
        }
        Ok(CommitTemplate::from_diff(&diff))
    }

    /// Produce a pull-request style review for the given range of commits.
    pub fn generate_pr_review(
        &self,
        base: Option<&str>,
        head: Option<&str>,
    ) -> Result<PullRequestReview> {
        let head = head.unwrap_or("HEAD");
        let base = base.unwrap_or("origin/main");
        let merge_base = resolve_merge_base(&self.repo_root, base, head)?;
        let diff = capture_range_diff(&self.repo_root, &merge_base, head)?;
        if diff.trim().is_empty() {
            return Err(Error::InvalidInput(
                "The computed diff between the selected refs is empty.".to_string(),
            ));
        }
        let stats = DiffStatistics::from_diff(&diff);
        let context = PullRequestContext {
            title: format!("Diff of {head} vs {base}"),
            body: None,
            author: None,
            base_branch: base.to_string(),
            head_branch: head.to_string(),
            additions: stats.additions as u64,
            deletions: stats.deletions as u64,
            changed_files: stats.files as u64,
            html_url: None,
        };
        Ok(PullRequestReview::from_diff(context, &diff))
    }
}

/// Summarised information about a changed file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileChange {
    pub old_path: String,
    pub new_path: String,
    pub change: ChangeKind,
    pub additions: usize,
    pub deletions: usize,
}

impl FileChange {
    pub fn primary_path(&self) -> &str {
        if !self.new_path.is_empty() {
            &self.new_path
        } else if !self.old_path.is_empty() {
            &self.old_path
        } else {
            ""
        }
    }

    pub fn is_test(&self) -> bool {
        FILE_TEST_HINTS
            .iter()
            .any(|hint| self.primary_path().contains(hint))
    }

    pub fn is_doc(&self) -> bool {
        DOC_EXTENSIONS
            .iter()
            .any(|ext| self.primary_path().ends_with(ext))
            || self.primary_path().starts_with("docs/")
    }

    pub fn is_config(&self) -> bool {
        CONFIG_EXTENSIONS
            .iter()
            .any(|ext| self.primary_path().ends_with(ext))
    }

    pub fn is_code(&self) -> bool {
        !self.is_doc() && !self.is_config()
    }
}

/// Change classification for a diff entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ChangeKind {
    Added,
    Removed,
    Modified,
    Renamed { from: String },
}

/// Structured conventional commit template recommendation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommitTemplate {
    pub prefix: String,
    pub summary: Vec<String>,
    pub sections: Vec<CommitTemplateSection>,
    pub workflow: Vec<WorkflowStep>,
}

impl CommitTemplate {
    pub fn from_diff(diff: &str) -> Self {
        let changes = parse_file_changes(diff);
        let metrics = DiffMetrics::from_changes(&changes);
        let prefix = select_conventional_prefix(&metrics);
        let summary = changes.iter().map(format_change_summary).collect();
        let sections = build_commit_sections(&metrics);
        let workflow = vec![
            WorkflowStep::new(
                "Parse diff",
                format!("Identified {} files", metrics.changed_files),
            ),
            WorkflowStep::new(
                "Choose conventional prefix",
                format!("Selected `{}` based on touched domains", prefix),
            ),
            WorkflowStep::new("Assemble testing checklist", testing_summary(&sections)),
        ];

        Self {
            prefix: prefix.to_string(),
            summary,
            sections,
            workflow,
        }
    }

    /// Render the template as Markdown.
    pub fn render_markdown(&self) -> String {
        let mut out = String::new();
        out.push_str(&format!("{} <describe change>\n\n", self.prefix));
        if !self.summary.is_empty() {
            out.push_str("Summary:\n");
            for line in &self.summary {
                out.push_str("- ");
                out.push_str(line);
                out.push('\n');
            }
            out.push('\n');
        }

        for section in &self.sections {
            out.push_str(&format!("{}:\n", section.title));
            for line in &section.lines {
                out.push_str("- ");
                out.push_str(line);
                out.push('\n');
            }
            out.push('\n');
        }

        out.trim_end().to_string()
    }
}
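
// For reference, an illustrative (not verbatim) rendering for a small
// code-plus-tests diff; the exact lines depend on the parsed metrics:
//
//     feat: <describe change>
//
//     Summary:
//     - update src/foo.rs (+3, -0)
//
//     Testing:
//     - [x] unit tests
//     - [x] integration / e2e
//     - [ ] lint / fmt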

/// A named block of checklist items in the commit template.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommitTemplateSection {
    pub title: String,
    pub lines: Vec<String>,
}

/// Metadata about a pull request / change range.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PullRequestContext {
    pub title: String,
    pub body: Option<String>,
    pub author: Option<String>,
    pub base_branch: String,
    pub head_branch: String,
    pub additions: u64,
    pub deletions: u64,
    pub changed_files: u64,
    pub html_url: Option<String>,
}

/// Markdown-ready automation review artifact.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PullRequestReview {
    pub context: PullRequestContext,
    pub summary: String,
    pub highlights: Vec<String>,
    pub findings: Vec<ReviewFinding>,
    pub checklist: Vec<ReviewChecklistItem>,
    pub workflow: Vec<WorkflowStep>,
}

impl PullRequestReview {
    pub fn from_diff(context: PullRequestContext, diff: &str) -> Self {
        let changes = parse_file_changes(diff);
        let metrics = DiffMetrics::from_changes(&changes);
        let highlights = build_highlights(&changes, &metrics);
        let findings = analyze_findings(diff, &changes, &metrics);
        let checklist = build_review_checklist(&metrics, &changes);
        let summary = format!(
            "{} files touched (+{}, -{}) · base {} → head {}",
            metrics.changed_files,
            metrics.total_additions,
            metrics.total_deletions,
            context.base_branch,
            context.head_branch
        );
        let workflow = vec![
            WorkflowStep::new(
                "Collect diff metadata",
                format!(
                    "{} files, {} additions, {} deletions",
                    metrics.changed_files, metrics.total_additions, metrics.total_deletions
                ),
            ),
            WorkflowStep::new(
                "Assess risk",
                format!(
                    "{} potential issues detected",
                    findings
                        .iter()
                        .filter(|finding| finding.severity != ReviewSeverity::Info)
                        .count()
                ),
            ),
            WorkflowStep::new(
                "Prepare checklist",
                format!("{} follow-up items surfaced", checklist.len()),
            ),
        ];

        Self {
            context,
            summary,
            highlights,
            findings,
            checklist,
            workflow,
        }
    }

    /// Render a Markdown review body with sections for highlights, findings, and checklists.
    pub fn render_markdown(&self) -> String {
        let mut out = String::new();
        out.push_str(&format!("### Summary\n{}\n\n", self.summary));

        if !self.highlights.is_empty() {
            out.push_str("### Highlights\n");
            for highlight in &self.highlights {
                out.push_str("- ");
                out.push_str(highlight);
                out.push('\n');
            }
            out.push('\n');
        }

        if !self.findings.is_empty() {
            out.push_str("### Findings\n");
            for finding in &self.findings {
                out.push_str(&format!(
                    "- **{}**: {}\n",
                    finding.severity.label(),
                    finding.message
                ));
                if !finding.locations.is_empty() {
                    for loc in &finding.locations {
                        out.push_str(" - ");
                        out.push_str(loc);
                        out.push('\n');
                    }
                }
            }
            out.push('\n');
        }

        if !self.checklist.is_empty() {
            out.push_str("### Checklist\n");
            for item in &self.checklist {
                let box_mark = if item.completed { "[x]" } else { "[ ]" };
                out.push_str(&format!("- {} {}\n", box_mark, item.label));
            }
            out.push('\n');
        }

        out.trim_end().to_string()
    }
}

/// Individual review finding surfaced during heuristics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReviewFinding {
    pub severity: ReviewSeverity,
    pub message: String,
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub locations: Vec<String>,
}

impl ReviewFinding {
    fn new(severity: ReviewSeverity, message: impl Into<String>) -> Self {
        Self {
            severity,
            message: message.into(),
            locations: Vec::new(),
        }
    }

    fn with_location(mut self, location: impl Into<String>) -> Self {
        self.locations.push(location.into());
        self
    }
}

/// Severity classification for review findings.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub enum ReviewSeverity {
    Info,
    Low,
    Medium,
    High,
}

impl ReviewSeverity {
    pub fn label(&self) -> &'static str {
        match self {
            ReviewSeverity::Info => "info",
            ReviewSeverity::Low => "low",
            ReviewSeverity::Medium => "medium",
            ReviewSeverity::High => "high",
        }
    }
}

impl fmt::Display for ReviewSeverity {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.label())
    }
}

/// Checklist item exposed in reviews.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReviewChecklistItem {
    pub label: String,
    pub completed: bool,
}

impl ReviewChecklistItem {
    fn new(label: impl Into<String>, completed: bool) -> Self {
        Self {
            label: label.into(),
            completed,
        }
    }
}

/// High-level workflow steps surfaced for SDK-style automation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowStep {
    pub label: String,
    pub outcome: String,
}

impl WorkflowStep {
    pub fn new(label: impl Into<String>, outcome: impl Into<String>) -> Self {
        Self {
            label: label.into(),
            outcome: outcome.into(),
        }
    }
}

/// Aggregate diff metrics derived from parsed file changes.
#[derive(Debug, Default, Clone)]
struct DiffMetrics {
    changed_files: usize,
    total_additions: usize,
    total_deletions: usize,
    test_files: usize,
    doc_files: usize,
    config_files: usize,
    code_files: usize,
}

/// Summary statistics extracted from a diff.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiffStatistics {
    pub files: usize,
    pub additions: usize,
    pub deletions: usize,
}

impl DiffStatistics {
    pub fn from_diff(diff: &str) -> Self {
        Self {
            files: count_files(diff),
            additions: count_symbol(diff, '+'),
            deletions: count_symbol(diff, '-'),
        }
    }
}

/// Convenience helper that converts a diff into aggregate statistics.
pub fn summarize_diff(diff: &str) -> DiffStatistics {
    DiffStatistics::from_diff(diff)
}

impl DiffMetrics {
    fn from_changes(changes: &[FileChange]) -> Self {
        let mut metrics = DiffMetrics {
            changed_files: changes.len(),
            ..DiffMetrics::default()
        };
        for change in changes {
            metrics.total_additions += change.additions;
            metrics.total_deletions += change.deletions;
            if change.is_test() {
                metrics.test_files += 1;
            } else if change.is_doc() {
                metrics.doc_files += 1;
            } else if change.is_config() {
                metrics.config_files += 1;
            } else {
                metrics.code_files += 1;
            }
        }
        metrics
    }

    fn has_tests(&self) -> bool {
        self.test_files > 0
    }

    fn has_docs(&self) -> bool {
        self.doc_files > 0
    }
}

// -------------------------------------------------------------------------------------------------
// Diff parsing and heuristics
// -------------------------------------------------------------------------------------------------

fn parse_file_changes(diff: &str) -> Vec<FileChange> {
    let mut changes = Vec::new();
    let mut current: Option<FileChange> = None;

    for line in diff.lines() {
        if line.starts_with("diff --git ") {
            if let Some(change) = current.take() {
                changes.push(change);
            }
            let mut parts = line.split_whitespace().skip(2);
            let old = parts.next().unwrap_or("a/unknown");
            let new = parts.next().unwrap_or("b/unknown");
            current = Some(FileChange {
                old_path: strip_path_prefix(old, "a/"),
                new_path: strip_path_prefix(new, "b/"),
                change: ChangeKind::Modified,
                additions: 0,
                deletions: 0,
            });
        } else if let Some(change) = current.as_mut() {
            if line.starts_with("new file mode") {
                change.change = ChangeKind::Added;
            } else if line.starts_with("deleted file mode") {
                change.change = ChangeKind::Removed;
            } else if line.starts_with("rename from ") {
                let from = line.trim_start_matches("rename from ").trim();
                change.change = ChangeKind::Renamed {
                    from: strip_path_prefix(from, ""),
                };
                change.old_path = strip_path_prefix(from, "");
            } else if line.starts_with("rename to ") {
                let to = line.trim_start_matches("rename to ").trim();
                change.new_path = strip_path_prefix(to, "");
            } else if line.starts_with("--- ") {
                let old = line.trim_start_matches("--- ").trim();
                if old.starts_with('a') {
                    change.old_path = strip_path_prefix(old, "a/");
                }
            } else if line.starts_with("+++ ") {
                let new = line.trim_start_matches("+++ ").trim();
                if new.starts_with('b') {
                    change.new_path = strip_path_prefix(new, "b/");
                }
            } else if line.starts_with('+') && !line.starts_with("+++") {
                change.additions += 1;
            } else if line.starts_with('-') && !line.starts_with("---") {
                change.deletions += 1;
            }
        }
    }

    if let Some(change) = current.take() {
        changes.push(change);
    }

    changes
}
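
// A worked example of the parser above: given the header
// `diff --git a/src/foo.rs b/src/foo.rs` followed by two added (`+`) lines,
// the resulting entry is roughly
// `FileChange { old_path: "src/foo.rs", new_path: "src/foo.rs", change: Modified, additions: 2, deletions: 0 }`.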

fn format_change_summary(change: &FileChange) -> String {
    match &change.change {
        ChangeKind::Added => format!("add {} (+{})", change.primary_path(), change.additions),
        ChangeKind::Removed => format!("remove {} (-{})", change.primary_path(), change.deletions),
        ChangeKind::Renamed { from } => format!(
            "rename {} → {} (+{}, -{})",
            from,
            change.primary_path(),
            change.additions,
            change.deletions
        ),
        ChangeKind::Modified => format!(
            "update {} (+{}, -{})",
            change.primary_path(),
            change.additions,
            change.deletions
        ),
    }
}

fn select_conventional_prefix(metrics: &DiffMetrics) -> &'static str {
    if metrics.changed_files == 0 {
        return "chore:";
    }
    if metrics.doc_files > 0 && metrics.code_files == 0 && metrics.test_files == 0 {
        "docs:"
    } else if metrics.test_files > 0 && metrics.code_files == 0 && metrics.doc_files == 0 {
        "test:"
    } else if metrics.config_files > 0 && metrics.code_files == 0 {
        "chore:"
    } else if metrics.total_deletions > metrics.total_additions
        && metrics.doc_files == 0
        && metrics.test_files == 0
    {
        "refactor:"
    } else {
        "feat:"
    }
}
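
// Illustrative outcomes of the heuristic above: a docs-only diff yields
// `docs:`, a tests-only diff `test:`, a config-only diff `chore:`, a
// net-deletion code diff `refactor:`, and everything else falls through to
// `feat:`.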

fn build_commit_sections(metrics: &DiffMetrics) -> Vec<CommitTemplateSection> {
    let mut sections = Vec::new();
    let tests_label = if metrics.has_tests() { "[x]" } else { "[ ]" };
    let docs_label = if metrics.has_docs() { "[x]" } else { "[ ]" };

    sections.push(CommitTemplateSection {
        title: "Testing".to_string(),
        lines: vec![
            format!("{} unit tests", tests_label),
            format!("{} integration / e2e", tests_label),
            "[ ] lint / fmt".to_string(),
        ],
    });

    sections.push(CommitTemplateSection {
        title: "Documentation".to_string(),
        lines: vec![
            format!("{} docs updated", docs_label),
            "[ ] release notes".to_string(),
        ],
    });

    sections
}

fn testing_summary(sections: &[CommitTemplateSection]) -> String {
    sections
        .iter()
        .flat_map(|section| &section.lines)
        .filter(|line| line.contains("tests"))
        .cloned()
        .collect::<Vec<_>>()
        .join(", ")
}

fn build_highlights(changes: &[FileChange], metrics: &DiffMetrics) -> Vec<String> {
    let mut highlights = Vec::new();
    if metrics.code_files > 0 {
        highlights.push(format!(
            "{} code files modified (+{}, -{})",
            metrics.code_files, metrics.total_additions, metrics.total_deletions
        ));
    }
    if metrics.has_tests() {
        highlights.push(format!("{} test files updated", metrics.test_files));
    } else if metrics.code_files > 0 {
        highlights.push("No test files updated; consider adding coverage.".to_string());
    }
    if metrics.has_docs() {
        highlights.push(format!("{} documentation files updated", metrics.doc_files));
    }

    for change in changes
        .iter()
        .filter(|change| change.additions + change.deletions > 400)
    {
        highlights.push(format!(
"Large change in {} ({:+} / {:-})",
            change.primary_path(),
            change.additions,
            change.deletions
        ));
    }

    highlights
}

fn analyze_findings(
    diff: &str,
    changes: &[FileChange],
    metrics: &DiffMetrics,
) -> Vec<ReviewFinding> {
    let mut findings = Vec::new();
    if !metrics.has_tests() && metrics.code_files > 0 {
        findings.push(
            ReviewFinding::new(
                ReviewSeverity::Medium,
                "Code changes detected without accompanying tests.",
            )
            .with_location("Consider adding unit or integration coverage."),
        );
    }

    let mut risky_locations: Vec<String> = Vec::new();
    for change in changes.iter().filter(|change| change.is_code()) {
        if change.additions > 0 && change.deletions == 0 && change.additions > 200 {
            findings.push(
                ReviewFinding::new(
                    ReviewSeverity::Low,
                    format!("Large addition in {}", change.primary_path()),
                )
                .with_location(format!("{} lines added", change.additions)),
            );
        }
    }

    for line in diff.lines() {
        if line.starts_with('+') {
            if line.contains("unwrap(") || line.contains(".expect(") {
                risky_locations.push(line.trim_start_matches('+').trim().to_string());
            } else if line.contains("unsafe ") {
                findings.push(
                    ReviewFinding::new(
                        ReviewSeverity::High,
                        "Usage of `unsafe` detected; ensure invariants are documented.",
                    )
                    .with_location(line.trim()),
                );
            } else if line.contains("todo!") || line.contains("unimplemented!") {
                findings.push(
                    ReviewFinding::new(
                        ReviewSeverity::Medium,
                        "TODO/unimplemented marker introduced.",
                    )
                    .with_location(line.trim()),
                );
            }
        }
    }

    if !risky_locations.is_empty() {
        findings.push(ReviewFinding {
            severity: ReviewSeverity::Low,
            message: "New unwrap()/expect() calls introduced; confirm they are infallible."
                .to_string(),
            locations: risky_locations,
        });
    }

    findings
}

fn build_review_checklist(
    metrics: &DiffMetrics,
    changes: &[FileChange],
) -> Vec<ReviewChecklistItem> {
    let mut checklist = Vec::new();
    checklist.push(ReviewChecklistItem::new(
        "Tests cover the change surface",
        metrics.has_tests(),
    ));
    checklist.push(ReviewChecklistItem::new(
        "Documentation updated if behaviour changed",
        metrics.has_docs(),
    ));

    let includes_release_artifacts = changes.iter().any(|change| {
        let path = change.primary_path();
        path.ends_with("CHANGELOG.md") || path.contains("release")
    });
    if !includes_release_artifacts {
        checklist.push(ReviewChecklistItem::new(
            "Changelog or release notes updated if required",
            false,
        ));
    }

    checklist
}

// -------------------------------------------------------------------------------------------------
// Git helpers
// -------------------------------------------------------------------------------------------------

fn discover_repo_root(path: &Path) -> Result<PathBuf> {
    let output = Command::new("git")
        .arg("rev-parse")
        .arg("--show-toplevel")
        .current_dir(path)
        .output()?;
    if !output.status.success() {
        return Err(Error::InvalidInput(
            "The current directory is not inside a git repository.".to_string(),
        ));
    }
    let stdout = String::from_utf8_lossy(&output.stdout);
    Ok(PathBuf::from(stdout.trim()))
}

fn capture_diff(root: &Path, mode: DiffCaptureMode<'_>) -> Result<String> {
    let mut cmd = Command::new("git");
    cmd.current_dir(root);
    match mode {
        DiffCaptureMode::Staged => {
            cmd.args(["diff", "--cached", "--unified=3", "--no-color"]);
        }
        DiffCaptureMode::WorkingTree => {
            cmd.args(["diff", "--unified=3", "--no-color"]);
        }
        DiffCaptureMode::Range { base, head } => {
            cmd.args([
                "diff",
                "--unified=3",
                "--no-color",
                &format!("{base}..{head}"),
            ]);
        }
    }
    let output = cmd.output()?;
    if !output.status.success() {
        return Err(Error::Unknown(format!(
            "git diff exited with status {}",
            output.status
        )));
    }
    Ok(String::from_utf8_lossy(&output.stdout).to_string())
}

fn capture_range_diff(root: &Path, base: &str, head: &str) -> Result<String> {
    let output = Command::new("git")
        .args([
            "diff",
            "--unified=3",
            "--no-color",
            &format!("{base}..{head}"),
        ])
        .current_dir(root)
        .output()?;
    if !output.status.success() {
        return Err(Error::Unknown(format!(
            "git diff exited with status {}",
            output.status
        )));
    }
    Ok(String::from_utf8_lossy(&output.stdout).to_string())
}

fn resolve_merge_base(root: &Path, base: &str, head: &str) -> Result<String> {
    let output = Command::new("git")
        .args(["merge-base", base, head])
        .current_dir(root)
        .output()?;
    if !output.status.success() {
        return Err(Error::Unknown(format!(
            "git merge-base exited with status {}",
            output.status
        )));
    }
    Ok(String::from_utf8_lossy(&output.stdout).trim().to_string())
}

// -------------------------------------------------------------------------------------------------
// Utility helpers
// -------------------------------------------------------------------------------------------------

fn strip_path_prefix(value: &str, prefix: &str) -> String {
    value
        .trim()
        .trim_matches('"')
        .trim_start_matches(prefix)
        .trim_start_matches("./")
        .to_string()
}

fn count_symbol(diff: &str, symbol: char) -> usize {
    diff.lines()
        .filter(|line| {
            line.starts_with(symbol)
                && !matches!(
                    (symbol, line.chars().nth(1), line.chars().nth(2)),
                    ('+', Some('+'), Some('+')) | ('-', Some('-'), Some('-'))
                )
        })
        .count()
}

fn count_files(diff: &str) -> usize {
    diff.lines()
        .filter(|line| line.starts_with("diff --git"))
        .count()
}

static FILE_TEST_HINTS: [&str; 5] = ["tests/", "_test.", "test/", "spec/", "fixtures/"];
static DOC_EXTENSIONS: [&str; 6] = [".md", ".rst", ".adoc", ".txt", ".mdx", ".markdown"];
static CONFIG_EXTENSIONS: [&str; 6] = [".toml", ".yaml", ".yml", ".json", ".ini", ".conf"];

// -------------------------------------------------------------------------------------------------
// Tests
// -------------------------------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    const SAMPLE_DIFF: &str = r#"diff --git a/src/foo.rs b/src/foo.rs
index e69de29..4b825dc 100644
--- a/src/foo.rs
+++ b/src/foo.rs
@@
+pub fn add(left: i32, right: i32) -> i32 {
+ left + right
+}
diff --git a/tests/foo_test.rs b/tests/foo_test.rs
new file mode 100644
index 0000000..bf3b82c
--- /dev/null
+++ b/tests/foo_test.rs
@@
+#[test]
+fn add_adds_numbers() {
+ assert_eq!(crate::add(2, 2), 4);
+}
diff --git a/README.md b/README.md
index 4b825dc..c8f2615 100644
--- a/README.md
+++ b/README.md
@@
-# Owlen
+# Owlen
+Updated docs
"#;

    #[test]
    fn commit_template_infers_prefix_and_sections() {
        let template = CommitTemplate::from_diff(SAMPLE_DIFF);
        assert_eq!(template.prefix, "feat:");
        assert_eq!(template.summary.len(), 3);
        assert_eq!(template.sections.len(), 2);
        let tests_section = template
            .sections
            .iter()
            .find(|section| section.title == "Testing")
            .expect("testing section");
        assert!(tests_section.lines.iter().any(|line| line.contains("[x]")));
        let markdown = template.render_markdown();
        assert!(markdown.contains("Summary:"));
        assert!(markdown.contains("tests"));
    }

    #[test]
    fn review_highlights_tests_gap() {
        let diff = r#"diff --git a/src/lib.rs b/src/lib.rs
index e69de29..4b825dc 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@
+pub fn risky() {
+ let value = std::env::var("MISSING").unwrap();
+ println!("{}", value);
+}
"#;
        let context = PullRequestContext {
            title: "Test PR".to_string(),
            body: None,
            author: Some("demo".to_string()),
            base_branch: "main".to_string(),
            head_branch: "feature".to_string(),
            additions: 3,
            deletions: 0,
            changed_files: 1,
            html_url: None,
        };
        let review = PullRequestReview::from_diff(context, diff);
        assert!(
            review
                .findings
                .iter()
                .any(|finding| finding.severity == ReviewSeverity::Medium)
        );
        assert!(
            review
                .findings
                .iter()
                .any(|finding| finding.message.contains("unwrap"))
        );
        let markdown = review.render_markdown();
        assert!(markdown.contains("Summary"));
        assert!(markdown.contains("Checklist"));
    }
}
File diff suppressed because it is too large
@@ -1,312 +0,0 @@
use std::collections::HashMap;
use std::io::{self, Write};
use std::sync::Arc;

use anyhow::Result;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};

use crate::encryption::VaultHandle;
use crate::tools::canonical_tool_name;

#[derive(Clone, Debug)]
pub struct ConsentRequest {
    pub tool_name: String,
}

/// Scope of consent grant
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
pub enum ConsentScope {
    /// Grant only for this single operation
    Once,
    /// Grant for the duration of the current session
    Session,
    /// Grant permanently (persisted across sessions)
    Permanent,
    /// Explicitly denied
    Denied,
}
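
// A minimal flow sketch (hypothetical caller code): grant a session-scoped
// consent, check it, then clear it when the session ends.
//
//     let mut manager = ConsentManager::new();
//     manager.grant_consent_with_scope(
//         "web_search",
//         vec!["query text".to_string()],
//         vec!["https://example.com".to_string()],
//         ConsentScope::Session,
//     );
//     assert!(manager.has_consent("web_search"));
//     manager.clear_session_consent();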

#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ConsentRecord {
    pub tool_name: String,
    pub scope: ConsentScope,
    pub timestamp: DateTime<Utc>,
    pub data_types: Vec<String>,
    pub external_endpoints: Vec<String>,
}

#[derive(Serialize, Deserialize, Default)]
pub struct ConsentManager {
    /// Permanent consent records (persisted to vault)
    permanent_records: HashMap<String, ConsentRecord>,
    /// Session-scoped consent (cleared on manager drop or explicit clear)
    #[serde(skip)]
    session_records: HashMap<String, ConsentRecord>,
    /// Once-scoped consent (used once then cleared)
    #[serde(skip)]
    once_records: HashMap<String, ConsentRecord>,
    /// Pending consent requests (to prevent duplicate prompts)
    #[serde(skip)]
    pending_requests: HashMap<String, ()>,
}

impl ConsentManager {
    pub fn new() -> Self {
        Self::default()
    }

    /// Load consent records from vault storage
    pub fn from_vault(vault: &Arc<std::sync::Mutex<VaultHandle>>) -> Self {
        let guard = vault.lock().expect("Vault mutex poisoned");
        if let Some(permanent_records) =
            guard
                .settings()
                .get("consent_records")
                .and_then(|consent_data| {
                    serde_json::from_value::<HashMap<String, ConsentRecord>>(consent_data.clone())
                        .ok()
                })
        {
            return Self {
                permanent_records,
                session_records: HashMap::new(),
                once_records: HashMap::new(),
                pending_requests: HashMap::new(),
            };
        }
        Self::default()
    }

    /// Persist permanent consent records to vault storage
    pub fn persist_to_vault(&self, vault: &Arc<std::sync::Mutex<VaultHandle>>) -> Result<()> {
        let mut guard = vault.lock().expect("Vault mutex poisoned");
        let consent_json = serde_json::to_value(&self.permanent_records)?;
        guard
            .settings_mut()
            .insert("consent_records".to_string(), consent_json);
        guard.persist()?;
        Ok(())
    }

    pub fn request_consent(
        &mut self,
        tool_name: &str,
        data_types: Vec<String>,
        endpoints: Vec<String>,
    ) -> Result<ConsentScope> {
        let canonical = canonical_tool_name(tool_name);

        // Check if already granted permanently
        if self
            .permanent_records
            .get(canonical)
            .is_some_and(|existing| existing.scope == ConsentScope::Permanent)
        {
            return Ok(ConsentScope::Permanent);
        }

        // Check if granted for session
        if self
            .session_records
            .get(canonical)
            .is_some_and(|existing| existing.scope == ConsentScope::Session)
        {
            return Ok(ConsentScope::Session);
        }

        // Check if request is already pending (prevent duplicate prompts)
        if self.pending_requests.contains_key(canonical) {
            // Wait for the other prompt to complete by returning denied temporarily
            // The caller should retry after a short delay
            return Ok(ConsentScope::Denied);
        }

        // Mark as pending
        self.pending_requests.insert(canonical.to_string(), ());

        // Show consent dialog and get scope
        let scope = self.show_consent_dialog(tool_name, &data_types, &endpoints)?;

        // Remove from pending
        self.pending_requests.remove(canonical);

        // Create record based on scope
        let record = ConsentRecord {
            tool_name: canonical.to_string(),
            scope: scope.clone(),
            timestamp: Utc::now(),
            data_types,
            external_endpoints: endpoints,
        };

        // Store in appropriate location
        match scope {
            ConsentScope::Permanent => {
                self.permanent_records.insert(canonical.to_string(), record);
            }
            ConsentScope::Session => {
                self.session_records.insert(canonical.to_string(), record);
            }
            ConsentScope::Once | ConsentScope::Denied => {
                // Don't store, just return the decision
            }
        }

        Ok(scope)
    }

    /// Grant consent programmatically (for TUI or automated flows)
    pub fn grant_consent(
        &mut self,
        tool_name: &str,
        data_types: Vec<String>,
        endpoints: Vec<String>,
    ) {
        self.grant_consent_with_scope(tool_name, data_types, endpoints, ConsentScope::Permanent);
    }

    /// Grant consent with specific scope
    pub fn grant_consent_with_scope(
        &mut self,
        tool_name: &str,
        data_types: Vec<String>,
        endpoints: Vec<String>,
        scope: ConsentScope,
    ) {
        let canonical = canonical_tool_name(tool_name);
        let record = ConsentRecord {
            tool_name: canonical.to_string(),
            scope: scope.clone(),
            timestamp: Utc::now(),
            data_types,
            external_endpoints: endpoints,
        };

        match scope {
            ConsentScope::Permanent => {
                self.permanent_records.insert(canonical.to_string(), record);
            }
            ConsentScope::Session => {
                self.session_records.insert(canonical.to_string(), record);
            }
            ConsentScope::Once => {
                self.once_records.insert(canonical.to_string(), record);
            }
            ConsentScope::Denied => {} // Denied is not stored
        }
    }

    /// Check if consent is needed (returns None if already granted, Some(info) if needed)
    pub fn check_consent_needed(&self, tool_name: &str) -> Option<ConsentRequest> {
        let canonical = canonical_tool_name(tool_name);
        if self.has_consent(canonical) {
            None
        } else {
            Some(ConsentRequest {
                tool_name: canonical.to_string(),
            })
        }
    }

    pub fn has_consent(&self, tool_name: &str) -> bool {
        let canonical = canonical_tool_name(tool_name);
        // Check permanent first, then session, then once
        self.permanent_records
            .get(canonical)
            .map(|r| r.scope == ConsentScope::Permanent)
            .or_else(|| {
                self.session_records
                    .get(canonical)
                    .map(|r| r.scope == ConsentScope::Session)
            })
            .or_else(|| {
                self.once_records
                    .get(canonical)
                    .map(|r| r.scope == ConsentScope::Once)
            })
            .unwrap_or(false)
    }

    /// Consume "once" consent for a tool (clears it after first use)
    pub fn consume_once_consent(&mut self, tool_name: &str) {
        let canonical = canonical_tool_name(tool_name);
        self.once_records.remove(canonical);
    }

    pub fn revoke_consent(&mut self, tool_name: &str) {
        let canonical = canonical_tool_name(tool_name);
        self.permanent_records.remove(canonical);
        self.session_records.remove(canonical);
        self.once_records.remove(canonical);
    }

    pub fn clear_all_consent(&mut self) {
        self.permanent_records.clear();
        self.session_records.clear();
        self.once_records.clear();
    }

    /// Clear only session-scoped consent (useful when starting new session)
    pub fn clear_session_consent(&mut self) {
        self.session_records.clear();
        self.once_records.clear(); // Also clear once consent on session clear
    }

    /// Check if consent is needed for a tool (non-blocking)
    /// Returns Some with consent details if needed, None if already granted
    pub fn check_if_consent_needed(
        &self,
        tool_name: &str,
        data_types: Vec<String>,
        endpoints: Vec<String>,
    ) -> Option<(String, Vec<String>, Vec<String>)> {
        let canonical = canonical_tool_name(tool_name);
        if self.has_consent(canonical) {
            return None;
        }
        Some((canonical.to_string(), data_types, endpoints))
    }

    fn show_consent_dialog(
        &self,
        tool_name: &str,
        data_types: &[String],
        endpoints: &[String],
    ) -> Result<ConsentScope> {
        // TEMPORARY: Auto-grant session consent when not in a proper terminal (TUI mode)
        // TODO: Integrate consent UI into the TUI event loop
        use std::io::IsTerminal;
        if !io::stdin().is_terminal() || std::env::var("OWLEN_AUTO_CONSENT").is_ok() {
            eprintln!("Auto-granting session consent for {} (TUI mode)", tool_name);
            return Ok(ConsentScope::Session);
        }

        println!("\n╔══════════════════════════════════════════════════╗");
        println!("║ 🔒 PRIVACY CONSENT REQUIRED 🔒 ║");
        println!("╚══════════════════════════════════════════════════╝");
        println!();
        println!("Tool: {}", tool_name);
        println!("Data: {}", data_types.join(", "));
        println!("Endpoints: {}", endpoints.join(", "));
        println!();
        println!("Choose consent scope:");
        println!(" [1] Allow once - Grant only for this operation");
        println!(" [2] Allow session - Grant for current session");
        println!(" [3] Allow always - Grant permanently");
        println!(" [4] Deny - Reject this operation");
        println!();
        print!("Enter choice (1-4) [default: 4]: ");
        io::stdout().flush()?;

        let mut input = String::new();
        io::stdin().read_line(&mut input)?;

        match input.trim() {
            "1" => Ok(ConsentScope::Once),
            "2" => Ok(ConsentScope::Session),
            "3" => Ok(ConsentScope::Permanent),
            _ => Ok(ConsentScope::Denied),
        }
    }
}
@@ -1,448 +0,0 @@
use crate::Result;
use crate::storage::StorageManager;
use crate::types::{Conversation, Message, MessageAttachment};
use serde_json::{Number, Value};
use std::collections::{HashMap, VecDeque};
use std::time::{Duration, Instant};
use uuid::Uuid;

const STREAMING_FLAG: &str = "streaming";
const LAST_CHUNK_TS: &str = "last_chunk_ts";
const PLACEHOLDER_FLAG: &str = "placeholder";

/// Manage active and historical conversations, including streaming updates.
pub struct ConversationManager {
    active: Conversation,
    history: VecDeque<Conversation>,
    message_index: HashMap<Uuid, usize>,
    streaming: HashMap<Uuid, StreamingMetadata>,
    max_history: usize,
}

#[derive(Debug, Clone)]
pub struct StreamingMetadata {
    started: Instant,
    last_update: Instant,
}

impl ConversationManager {
    /// Create a new conversation manager with a default model
    pub fn new(model: impl Into<String>) -> Self {
        Self::with_history_capacity(model, 32)
    }

    /// Create with explicit history capacity
    pub fn with_history_capacity(model: impl Into<String>, max_history: usize) -> Self {
        let conversation = Conversation::new(model.into());
        Self {
            active: conversation,
            history: VecDeque::new(),
            message_index: HashMap::new(),
            streaming: HashMap::new(),
            max_history: max_history.max(1),
        }
    }

    /// Access the active conversation
    pub fn active(&self) -> &Conversation {
        &self.active
    }

    /// Public mutable access to the active conversation
    pub fn active_mut(&mut self) -> &mut Conversation {
        &mut self.active
    }

    /// Replace the active conversation with a provided one, archiving the existing conversation if it contains data
    pub fn load(&mut self, conversation: Conversation) {
        if !self.active.messages.is_empty() {
            self.archive_active();
        }

        self.message_index.clear();
        for (idx, message) in conversation.messages.iter().enumerate() {
            self.message_index.insert(message.id, idx);
        }

        self.stream_reset();
        self.active = conversation;
    }

    /// Start a brand new conversation, archiving the previous one
    pub fn start_new(&mut self, model: Option<String>, name: Option<String>) {
        self.archive_active();
        let model = model.unwrap_or_else(|| self.active.model.clone());
        self.active = Conversation::new(model);
        self.active.name = name;
        self.message_index.clear();
        self.stream_reset();
    }

    /// Archive the active conversation into history
    pub fn archive_active(&mut self) {
        if self.active.messages.is_empty() {
            return;
        }

        let mut archived = self.active.clone();
        archived.updated_at = std::time::SystemTime::now();
        self.history.push_front(archived);

        while self.history.len() > self.max_history {
            self.history.pop_back();
        }
    }

    /// Get immutable history
    pub fn history(&self) -> impl Iterator<Item = &Conversation> {
        self.history.iter()
    }

    /// Add a user message and return its identifier
    pub fn push_user_message(&mut self, content: impl Into<String>) -> Uuid {
        let message = Message::user(content.into());
        self.register_message(message)
    }

    /// Add a user message that includes rich attachments.
    pub fn push_user_message_with_attachments(
        &mut self,
        content: impl Into<String>,
        attachments: Vec<MessageAttachment>,
    ) -> Uuid {
        let message = Message::user(content.into()).with_attachments(attachments);
        self.register_message(message)
    }

    /// Add a system message and return its identifier
    pub fn push_system_message(&mut self, content: impl Into<String>) -> Uuid {
        let message = Message::system(content.into());
        self.register_message(message)
    }

    /// Add an assistant message (non-streaming) and return its identifier
    pub fn push_assistant_message(&mut self, content: impl Into<String>) -> Uuid {
        let message = Message::assistant(content.into());
        self.register_message(message)
    }

    /// Push an arbitrary message into the active conversation
    pub fn push_message(&mut self, message: Message) -> Uuid {
        self.register_message(message)
    }

    /// Start tracking a streaming assistant response, returning the message id to update
    pub fn start_streaming_response(&mut self) -> Uuid {
        let mut message = Message::assistant(String::new());
        message
            .metadata
            .insert(STREAMING_FLAG.to_string(), Value::Bool(true));
        let id = message.id;
        self.register_message(message);
        self.streaming.insert(
            id,
            StreamingMetadata {
                started: Instant::now(),
                last_update: Instant::now(),
            },
        );
        id
    }
|
||||
|
||||
/// Append streaming content to an assistant message
|
||||
pub fn append_stream_chunk(
|
||||
&mut self,
|
||||
message_id: Uuid,
|
||||
chunk: &str,
|
||||
is_final: bool,
|
||||
) -> Result<()> {
|
||||
let index = self
|
||||
.message_index
|
||||
.get(&message_id)
|
||||
.copied()
|
||||
.ok_or_else(|| crate::Error::Unknown(format!("Unknown message id: {message_id}")))?;
|
||||
|
||||
let conversation = self.active_mut();
|
||||
if let Some(message) = conversation.messages.get_mut(index) {
|
||||
let was_placeholder = message
|
||||
.metadata
|
||||
.remove(PLACEHOLDER_FLAG)
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
if was_placeholder {
|
||||
message.content.clear();
|
||||
}
|
||||
|
||||
if !chunk.is_empty() {
|
||||
message.content.push_str(chunk);
|
||||
}
|
||||
message.timestamp = std::time::SystemTime::now();
|
||||
let millis = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_millis() as u64;
|
||||
message.metadata.insert(
|
||||
LAST_CHUNK_TS.to_string(),
|
||||
Value::Number(Number::from(millis)),
|
||||
);
|
||||
|
||||
if is_final {
|
||||
message
|
||||
.metadata
|
||||
.insert(STREAMING_FLAG.to_string(), Value::Bool(false));
|
||||
self.streaming.remove(&message_id);
|
||||
} else if let Some(info) = self.streaming.get_mut(&message_id) {
|
||||
info.last_update = Instant::now();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
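
    // A hedged sketch of the intended streaming lifecycle (the model name and
    // chunk strings below are illustrative, not values this crate defines):
    //
    //     let mut manager = ConversationManager::new("example-model");
    //     manager.push_user_message("Hello");
    //     let id = manager.start_streaming_response();
    //     manager.append_stream_chunk(id, "Hi ", false)?;
    //     manager.append_stream_chunk(id, "there!", true)?; // final chunk clears the streaming flag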

    /// Replace the current streaming content for a message.
    pub fn set_stream_content(
        &mut self,
        message_id: Uuid,
        content: impl Into<String>,
        is_final: bool,
    ) -> Result<()> {
        let index = self
            .message_index
            .get(&message_id)
            .copied()
            .ok_or_else(|| crate::Error::Unknown(format!("Unknown message id: {message_id}")))?;

        let conversation = self.active_mut();
        if let Some(message) = conversation.messages.get_mut(index) {
            message.content = content.into();
            message.metadata.remove(PLACEHOLDER_FLAG);
            message.timestamp = std::time::SystemTime::now();
            let millis = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_millis() as u64;
            message.metadata.insert(
                LAST_CHUNK_TS.to_string(),
                Value::Number(Number::from(millis)),
            );

            if is_final {
                message
                    .metadata
                    .insert(STREAMING_FLAG.to_string(), Value::Bool(false));
                self.streaming.remove(&message_id);
            } else if let Some(info) = self.streaming.get_mut(&message_id) {
                info.last_update = Instant::now();
            }
        }

        Ok(())
    }

    /// Set placeholder text for a streaming message
    pub fn set_stream_placeholder(
        &mut self,
        message_id: Uuid,
        text: impl Into<String>,
    ) -> Result<()> {
        let index = self
            .message_index
            .get(&message_id)
            .copied()
            .ok_or_else(|| crate::Error::Unknown(format!("Unknown message id: {message_id}")))?;

        if let Some(message) = self.active_mut().messages.get_mut(index) {
            message.content = text.into();
            message.timestamp = std::time::SystemTime::now();
            message
                .metadata
                .insert(PLACEHOLDER_FLAG.to_string(), Value::Bool(true));
        }

        Ok(())
    }

    pub fn cancel_stream(&mut self, message_id: Uuid, notice: impl Into<String>) -> Result<()> {
        let index = self
            .message_index
            .get(&message_id)
            .copied()
            .ok_or_else(|| crate::Error::Unknown(format!("Unknown message id: {message_id}")))?;

        if let Some(message) = self.active_mut().messages.get_mut(index) {
            message.content = notice.into();
            message.timestamp = std::time::SystemTime::now();
            message
                .metadata
                .insert(STREAMING_FLAG.to_string(), Value::Bool(false));
            message.metadata.remove(PLACEHOLDER_FLAG);
            let millis = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_millis() as u64;
            message.metadata.insert(
                LAST_CHUNK_TS.to_string(),
                Value::Number(Number::from(millis)),
            );
        }

        self.streaming.remove(&message_id);
        Ok(())
    }

    /// Set tool calls on a streaming message
    pub fn set_tool_calls_on_message(
        &mut self,
        message_id: Uuid,
        tool_calls: Vec<crate::types::ToolCall>,
    ) -> Result<()> {
        let index = self
            .message_index
            .get(&message_id)
            .copied()
            .ok_or_else(|| crate::Error::Unknown(format!("Unknown message id: {message_id}")))?;

        if let Some(message) = self.active_mut().messages.get_mut(index) {
            if tool_calls.is_empty() {
                message.tool_calls = None;
            } else {
                message.tool_calls = Some(tool_calls);
            }
        }

        Ok(())
    }

    /// Update the active model (used when the user changes model mid-session)
    pub fn set_model(&mut self, model: impl Into<String>) {
        self.active.model = model.into();
        self.active.updated_at = std::time::SystemTime::now();
    }

    /// Provide read access to the cached streaming metadata
    pub fn streaming_metadata(&self, message_id: &Uuid) -> Option<StreamingMetadata> {
        self.streaming.get(message_id).cloned()
    }

    /// Remove inactive streaming messages that have stalled beyond the provided timeout
    pub fn expire_stalled_streams(&mut self, idle_timeout: Duration) -> Vec<Uuid> {
        let cutoff = Instant::now() - idle_timeout;
        let mut expired = Vec::new();

        self.streaming.retain(|id, meta| {
            if meta.last_update < cutoff {
                expired.push(*id);
                false
            } else {
                true
            }
        });

        expired
    }

    /// Clear all state
    pub fn clear(&mut self) {
        self.active.clear();
        self.history.clear();
        self.message_index.clear();
        self.streaming.clear();
    }

    fn register_message(&mut self, message: Message) -> Uuid {
        let id = message.id;
        let idx;
        {
            let conversation = self.active_mut();
            idx = conversation.messages.len();
            conversation.messages.push(message);
            conversation.updated_at = std::time::SystemTime::now();
        }
        self.message_index.insert(id, idx);
        id
    }

    /// Replace the active conversation messages and rebuild internal indexes.
    pub fn replace_active_messages(&mut self, mut messages: Vec<Message>) {
        let now = std::time::SystemTime::now();
        for message in &mut messages {
            // Ensure message timestamps are not in the far past when rewired.
            message.timestamp = now;
        }
        self.active.messages = messages;
        self.active.updated_at = now;
        self.rebuild_index();
        self.stream_reset();
    }

    fn rebuild_index(&mut self) {
        self.message_index.clear();
        for (idx, message) in self.active.messages.iter().enumerate() {
            self.message_index.insert(message.id, idx);
        }
    }

    fn stream_reset(&mut self) {
        self.streaming.clear();
    }

    /// Save the active conversation to disk
    pub async fn save_active(
        &self,
        storage: &StorageManager,
        name: Option<String>,
    ) -> Result<Uuid> {
        storage.save_conversation(&self.active, name).await?;
        Ok(self.active.id)
    }

    /// Save the active conversation to disk with a description
    pub async fn save_active_with_description(
        &self,
        storage: &StorageManager,
        name: Option<String>,
        description: Option<String>,
    ) -> Result<Uuid> {
        storage
            .save_conversation_with_description(&self.active, name, description)
            .await?;
        Ok(self.active.id)
    }

    /// Load a conversation from storage and make it active
    pub async fn load_saved(&mut self, storage: &StorageManager, id: Uuid) -> Result<()> {
        let conversation = storage.load_conversation(id).await?;
        self.load(conversation);
        Ok(())
    }

    /// List all saved sessions
    pub async fn list_saved_sessions(
        storage: &StorageManager,
    ) -> Result<Vec<crate::storage::SessionMeta>> {
        storage.list_sessions().await
    }
}

impl StreamingMetadata {
    /// Duration since the stream started
    pub fn elapsed(&self) -> Duration {
        self.started.elapsed()
    }

    /// Duration since the last chunk was received
    pub fn idle_duration(&self) -> Duration {
        self.last_update.elapsed()
    }

    /// Timestamp when streaming started
    pub fn started_at(&self) -> Instant {
        self.started
    }

    /// Timestamp of most recent update
    pub fn last_update_at(&self) -> Instant {
        self.last_update
    }
}
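
// A usage sketch for stall detection, assuming the caller polls periodically
// (the 30-second timeout and the notice text are illustrative choices):
//
//     let stalled = manager.expire_stalled_streams(Duration::from_secs(30));
//     for id in stalled {
//         manager.cancel_stream(id, "[response timed out]")?;
//     }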
@@ -1,108 +0,0 @@
use std::sync::Arc;

use serde::{Deserialize, Serialize};

use crate::{Error, Result, oauth::OAuthToken, storage::StorageManager};

#[derive(Serialize, Deserialize, Debug)]
pub struct ApiCredentials {
    pub api_key: String,
    pub endpoint: String,
}

pub const OLLAMA_CLOUD_CREDENTIAL_ID: &str = "provider_ollama_cloud";

pub struct CredentialManager {
    storage: Arc<StorageManager>,
    master_key: Arc<Vec<u8>>,
    namespace: String,
}

impl CredentialManager {
    pub fn new(storage: Arc<StorageManager>, master_key: Arc<Vec<u8>>) -> Self {
        Self {
            storage,
            master_key,
            namespace: "owlen".to_string(),
        }
    }

    fn namespaced_key(&self, tool_name: &str) -> String {
        format!("{}_{}", self.namespace, tool_name)
    }

    fn oauth_storage_key(&self, resource: &str) -> String {
        self.namespaced_key(&format!("oauth_{resource}"))
    }

    pub async fn store_credentials(
        &self,
        tool_name: &str,
        credentials: &ApiCredentials,
    ) -> Result<()> {
        let key = self.namespaced_key(tool_name);
        let payload = serde_json::to_vec(credentials).map_err(|e| {
            Error::Storage(format!(
                "Failed to serialize credentials for secure storage: {e}"
            ))
        })?;
        self.storage
            .store_secure_item(&key, &payload, &self.master_key)
            .await
    }

    pub async fn get_credentials(&self, tool_name: &str) -> Result<Option<ApiCredentials>> {
        let key = self.namespaced_key(tool_name);
        match self
            .storage
            .load_secure_item(&key, &self.master_key)
            .await?
        {
            Some(bytes) => {
                let creds = serde_json::from_slice(&bytes).map_err(|e| {
                    Error::Storage(format!("Failed to deserialize stored credentials: {e}"))
                })?;
                Ok(Some(creds))
            }
            None => Ok(None),
        }
    }

    pub async fn delete_credentials(&self, tool_name: &str) -> Result<()> {
        let key = self.namespaced_key(tool_name);
        self.storage.delete_secure_item(&key).await
    }

    pub async fn store_oauth_token(&self, resource: &str, token: &OAuthToken) -> Result<()> {
        let key = self.oauth_storage_key(resource);
        let payload = serde_json::to_vec(token).map_err(|err| {
            Error::Storage(format!(
                "Failed to serialize OAuth token for secure storage: {err}"
            ))
        })?;
        self.storage
            .store_secure_item(&key, &payload, &self.master_key)
            .await
    }

    pub async fn load_oauth_token(&self, resource: &str) -> Result<Option<OAuthToken>> {
        let key = self.oauth_storage_key(resource);
        let raw = self
            .storage
            .load_secure_item(&key, &self.master_key)
            .await?;
        if let Some(bytes) = raw {
            let token = serde_json::from_slice(&bytes).map_err(|err| {
                Error::Storage(format!("Failed to deserialize stored OAuth token: {err}"))
            })?;
            Ok(Some(token))
        } else {
            Ok(None)
        }
    }

    pub async fn delete_oauth_token(&self, resource: &str) -> Result<()> {
        let key = self.oauth_storage_key(resource);
        self.storage.delete_secure_item(&key).await
    }
}
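
// Illustrative round-trip only: the "github" tool name, endpoint, and key are
// invented for this sketch; `storage` and `master_key` come from the vault
// bootstrap elsewhere in the crate.
//
//     let manager = CredentialManager::new(storage, master_key);
//     let creds = ApiCredentials {
//         api_key: "example-key".to_string(),
//         endpoint: "https://api.example.com".to_string(),
//     };
//     manager.store_credentials("github", &creds).await?;
//     assert!(manager.get_credentials("github").await?.is_some());
//     manager.delete_credentials("github").await?;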
@@ -1,265 +0,0 @@
// TODO: Upgrade to generic-array 1.x to remove deprecation warnings
#![allow(deprecated)]

use std::collections::HashMap;
use std::fs::{self, OpenOptions};
use std::io::{self, Write};
use std::path::{Path, PathBuf};

use aes_gcm::{
    Aes256Gcm, Nonce,
    aead::{Aead, KeyInit},
};
use anyhow::{Context, Result, bail};
use ring::rand::{SecureRandom, SystemRandom};
use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;

pub struct EncryptedStorage {
    cipher: Aes256Gcm,
    storage_path: PathBuf,
}

#[derive(Serialize, Deserialize)]
struct EncryptedData {
    nonce: [u8; 12],
    ciphertext: Vec<u8>,
}

#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct VaultData {
    pub master_key: Vec<u8>,
    #[serde(default)]
    pub settings: HashMap<String, JsonValue>,
}

pub struct VaultHandle {
    storage: EncryptedStorage,
    pub data: VaultData,
}

impl VaultHandle {
    pub fn master_key(&self) -> &[u8] {
        &self.data.master_key
    }

    pub fn settings(&self) -> &HashMap<String, JsonValue> {
        &self.data.settings
    }

    pub fn settings_mut(&mut self) -> &mut HashMap<String, JsonValue> {
        &mut self.data.settings
    }

    pub fn persist(&self) -> Result<()> {
        self.storage.store(&self.data)
    }
}

impl EncryptedStorage {
    pub fn new(storage_path: PathBuf, key: &[u8]) -> Result<Self> {
        if key.len() != 32 {
            bail!(
                "Invalid key length for encrypted storage ({}). Expected 32 bytes for AES-256.",
                key.len()
            );
        }
        let cipher = Aes256Gcm::new_from_slice(key)
            .map_err(|_| anyhow::anyhow!("Invalid key length for AES-256"))?;

        if let Some(parent) = storage_path.parent() {
            fs::create_dir_all(parent).context("Failed to ensure storage directory exists")?;
        }

        Ok(Self {
            cipher,
            storage_path,
        })
    }

    pub fn store<T: Serialize>(&self, data: &T) -> Result<()> {
        let json = serde_json::to_vec(data).context("Failed to serialize data")?;

        let nonce = generate_nonce()?;
        let nonce_ref = Nonce::from_slice(&nonce);

        let ciphertext = self
            .cipher
            .encrypt(nonce_ref, json.as_ref())
            .map_err(|e| anyhow::anyhow!("Encryption failed: {}", e))?;

        let encrypted_data = EncryptedData { nonce, ciphertext };
        let encrypted_json = serde_json::to_vec(&encrypted_data)?;

        fs::write(&self.storage_path, encrypted_json).context("Failed to write encrypted data")?;

        Ok(())
    }

    pub fn load<T: for<'de> Deserialize<'de>>(&self) -> Result<T> {
        let encrypted_json =
            fs::read(&self.storage_path).context("Failed to read encrypted data")?;

        let encrypted_data: EncryptedData =
            serde_json::from_slice(&encrypted_json).context("Failed to parse encrypted data")?;

        let nonce_ref = Nonce::from_slice(&encrypted_data.nonce);
        let plaintext = self
            .cipher
            .decrypt(nonce_ref, encrypted_data.ciphertext.as_ref())
            .map_err(|e| anyhow::anyhow!("Decryption failed: {}", e))?;

        let data: T =
            serde_json::from_slice(&plaintext).context("Failed to deserialize decrypted data")?;

        Ok(data)
    }

    pub fn exists(&self) -> bool {
        self.storage_path.exists()
    }

    pub fn delete(&self) -> Result<()> {
        if self.exists() {
            fs::remove_file(&self.storage_path).context("Failed to delete encrypted storage")?;
        }
        Ok(())
    }

    pub fn verify_password(&self) -> Result<()> {
        if !self.exists() {
            return Ok(());
        }

        let encrypted_json =
            fs::read(&self.storage_path).context("Failed to read encrypted data")?;

        if encrypted_json.is_empty() {
            return Ok(());
        }

        let encrypted_data: EncryptedData =
            serde_json::from_slice(&encrypted_json).context("Failed to parse encrypted data")?;

        let nonce_ref = Nonce::from_slice(&encrypted_data.nonce);
        self.cipher
            .decrypt(nonce_ref, encrypted_data.ciphertext.as_ref())
            .map(|_| ())
            .map_err(|e| anyhow::anyhow!("Decryption failed: {}", e))
    }
}

pub fn unlock(storage_path: PathBuf) -> Result<VaultHandle> {
    let key = load_or_create_encryption_key(&storage_path)?;
    let storage = EncryptedStorage::new(storage_path, &key)?;
    let data = load_or_initialize_vault(&storage)?;
    Ok(VaultHandle { storage, data })
}

fn load_or_initialize_vault(storage: &EncryptedStorage) -> Result<VaultData> {
    match storage.load::<VaultData>() {
        Ok(data) => {
            if data.master_key.len() != 32 {
                bail!(
                    "Corrupted vault: master key has invalid length ({}). \
                     Expected 32 bytes for AES-256. Vault cannot be recovered.",
                    data.master_key.len()
                );
            }
            Ok(data)
        }
        Err(err) => {
            if storage.exists() {
                return Err(err);
            }
            let data = VaultData {
                master_key: generate_master_key()?,
                ..Default::default()
            };
            storage.store(&data)?;
            Ok(data)
        }
    }
}

fn key_path(storage_path: &Path) -> PathBuf {
    let mut path = storage_path.to_path_buf();
    path.set_extension("key");
    path
}

fn load_or_create_encryption_key(storage_path: &Path) -> Result<Vec<u8>> {
    let key_path = key_path(storage_path);
    match fs::read(&key_path) {
        Ok(bytes) => {
            if bytes.len() == 32 {
                Ok(bytes)
            } else {
                bail!(
                    "Invalid encryption key length stored in {} ({} bytes). Expected 32 bytes.",
                    key_path.display(),
                    bytes.len()
                );
            }
        }
        Err(err) if err.kind() == io::ErrorKind::NotFound => {
            let key = generate_master_key()?;
            write_key_file(&key_path, &key)?;
            Ok(key)
        }
        Err(err) => Err(err)
            .with_context(|| format!("Failed to read encryption key from {}", key_path.display())),
    }
}

fn write_key_file(path: &Path, key: &[u8]) -> Result<()> {
    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent)
            .with_context(|| format!("Failed to create directory {}", parent.display()))?;
    }

    #[cfg(unix)]
    {
        use std::os::unix::fs::OpenOptionsExt;

        let mut file = OpenOptions::new()
            .create(true)
            .write(true)
            .truncate(true)
            .mode(0o600)
            .open(path)
            .with_context(|| format!("Failed to open encryption key file {}", path.display()))?;
        file.write_all(key)
            .with_context(|| format!("Failed to write encryption key file {}", path.display()))?;
    }

    #[cfg(not(unix))]
    {
        let mut file = OpenOptions::new()
            .create(true)
            .write(true)
            .truncate(true)
            .open(path)
            .with_context(|| format!("Failed to open encryption key file {}", path.display()))?;
        file.write_all(key)
            .with_context(|| format!("Failed to write encryption key file {}", path.display()))?;
    }

    Ok(())
}

fn generate_master_key() -> Result<Vec<u8>> {
    let mut key = vec![0u8; 32];
    SystemRandom::new()
        .fill(&mut key)
        .map_err(|_| anyhow::anyhow!("Failed to generate master key"))?;
    Ok(key)
}

fn generate_nonce() -> Result<[u8; 12]> {
    let mut nonce = [0u8; 12];
    let rng = SystemRandom::new();
    rng.fill(&mut nonce)
        .map_err(|_| anyhow::anyhow!("Failed to generate nonce"))?;
    Ok(nonce)
}
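
// Minimal vault-flow sketch, assuming the storage path is supplied by the
// application's configuration layer (the path below is illustrative):
//
//     let mut vault = unlock(PathBuf::from("/tmp/owlen-vault.bin"))?;
//     vault.settings_mut()
//         .insert("theme".to_string(), serde_json::json!("dark"));
//     vault.persist()?;
//     assert_eq!(vault.master_key().len(), 32);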
@@ -1,32 +0,0 @@
use std::sync::Arc;

use async_trait::async_trait;

use crate::{
    Result,
    llm::ChatStream,
    mcp::{McpToolCall, McpToolDescriptor, McpToolResponse},
    types::{ChatRequest, ChatResponse, ModelInfo},
};

/// Object-safe facade for interacting with LLM backends.
#[async_trait]
pub trait LlmClient: Send + Sync {
    /// List the models exposed by this client.
    async fn list_models(&self) -> Result<Vec<ModelInfo>>;

    /// Issue a one-shot chat request and wait for the complete response.
    async fn send_chat(&self, request: ChatRequest) -> Result<ChatResponse>;

    /// Stream chat responses incrementally.
    async fn stream_chat(&self, request: ChatRequest) -> Result<ChatStream>;

    /// Enumerate tools exposed by the backing provider.
    async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>>;

    /// Invoke a tool exposed by the provider.
    async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse>;
}

/// Convenience alias for trait-object clients.
pub type DynLlmClient = Arc<dyn LlmClient>;
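
// Call-pattern sketch: any `LlmClient` implementation can be erased behind the
// alias (`my_client` and `request` stand in for values built elsewhere):
//
//     let client: DynLlmClient = Arc::new(my_client);
//     let models = client.list_models().await?;
//     let reply = client.send_chat(request).await?;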
@@ -1 +0,0 @@
pub mod llm_client;
@@ -1,112 +0,0 @@
use crate::types::Message;
use crate::ui::RoleLabelDisplay;

/// Formats messages for display across different clients.
#[derive(Debug, Clone)]
pub struct MessageFormatter {
    wrap_width: usize,
    role_label_mode: RoleLabelDisplay,
    preserve_empty_lines: bool,
}

impl MessageFormatter {
    /// Create a new formatter
    pub fn new(wrap_width: usize, role_label_mode: RoleLabelDisplay) -> Self {
        Self {
            wrap_width: wrap_width.max(20),
            role_label_mode,
            preserve_empty_lines: false,
        }
    }

    /// Override whether empty lines should be preserved
    pub fn with_preserve_empty(mut self, preserve: bool) -> Self {
        self.preserve_empty_lines = preserve;
        self
    }

    /// Update the wrap width
    pub fn set_wrap_width(&mut self, width: usize) {
        self.wrap_width = width.max(20);
    }

    /// The configured role label layout preference.
    pub fn role_label_mode(&self) -> RoleLabelDisplay {
        self.role_label_mode
    }

    /// Whether any role label should be shown alongside messages.
    pub fn show_role_labels(&self) -> bool {
        !matches!(self.role_label_mode, RoleLabelDisplay::None)
    }

    /// Update the role label layout preference.
    pub fn set_role_label_mode(&mut self, mode: RoleLabelDisplay) {
        self.role_label_mode = mode;
    }

    pub fn format_message(&self, message: &Message) -> Vec<String> {
        message
            .content
            .trim()
            .lines()
            .map(|s| s.to_string())
            .collect()
    }

    /// Extract thinking content from <think> tags, returning (content_without_think, thinking_content).
    /// This handles both complete and incomplete (streaming) think tags.
    pub fn extract_thinking(&self, content: &str) -> (String, Option<String>) {
        let mut result = String::new();
        let mut thinking = String::new();
        let mut current_pos = 0;

        while let Some(start_pos) = content[current_pos..].find("<think>") {
            let abs_start = current_pos + start_pos;

            // Add content before <think> tag to result
            result.push_str(&content[current_pos..abs_start]);

            // Find closing tag
            if let Some(end_pos) = content[abs_start..].find("</think>") {
                let abs_end = abs_start + end_pos;
                let think_content = &content[abs_start + 7..abs_end]; // 7 = len("<think>")

                if !thinking.is_empty() {
                    thinking.push_str("\n\n");
                }
                thinking.push_str(think_content.trim());

                current_pos = abs_end + 8; // 8 = len("</think>")
            } else {
                // Unclosed tag - this is streaming content
                // Extract everything after <think> as thinking content
                let think_content = &content[abs_start + 7..]; // 7 = len("<think>")

                if !thinking.is_empty() {
                    thinking.push_str("\n\n");
                }
                thinking.push_str(think_content);

                current_pos = content.len();
                break;
            }
        }

        // Add remaining content
        result.push_str(&content[current_pos..]);

        let thinking_result = if thinking.is_empty() {
            None
        } else {
            Some(thinking)
        };

        // If the result is empty but we have thinking content, show a placeholder
        if result.trim().is_empty() && thinking_result.is_some() {
            result.push_str("[Thinking...]");
        }

        (result, thinking_result)
    }
}
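
// Behavior sketch for `extract_thinking` (input and expected output are
// illustrative; `RoleLabelDisplay::None` is used only to build a formatter):
//
//     let formatter = MessageFormatter::new(80, RoleLabelDisplay::None);
//     let (visible, thinking) =
//         formatter.extract_thinking("<think>plan steps</think>The answer is 4.");
//     assert_eq!(visible, "The answer is 4.");
//     assert_eq!(thinking.as_deref(), Some("plan steps"));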
@@ -1,258 +0,0 @@
use crate::automation::repo::{PullRequestContext, summarize_diff};
use crate::{Error, Result};
use reqwest::header::{ACCEPT, AUTHORIZATION, HeaderValue, USER_AGENT};
use serde::{Deserialize, Serialize};

const DEFAULT_API_ENDPOINT: &str = "https://api.github.com";
const USER_AGENT_VALUE: &str = "owlen/0.2";

/// Lightweight GitHub API client used for repository automation workflows.
pub struct GithubClient {
    client: reqwest::Client,
    base_url: String,
    token: Option<String>,
}

#[derive(Debug, Clone, Default)]
pub struct GithubConfig {
    pub token: Option<String>,
    pub api_endpoint: Option<String>,
}

impl GithubClient {
    pub fn new(config: GithubConfig) -> Result<Self> {
        let client = reqwest::Client::builder()
            .user_agent(USER_AGENT_VALUE)
            .build()
            .map_err(|err| Error::Network(err.to_string()))?;
        Ok(Self {
            client,
            base_url: config
                .api_endpoint
                .unwrap_or_else(|| DEFAULT_API_ENDPOINT.to_string()),
            token: config.token,
        })
    }

    /// Fetch a pull request, returning diff text along with contextual metadata.
    pub async fn pull_request(
        &self,
        owner: &str,
        repo: &str,
        number: u64,
    ) -> Result<PullRequestDetails> {
        let pr = self.fetch_pull_request(owner, repo, number).await?;
        let diff = self.fetch_diff(owner, repo, number).await?;
        let files = self.fetch_files(owner, repo, number).await?;
        let stats = summarize_diff(&diff);
        let context = PullRequestContext {
            title: pr
                .title
                .clone()
                .unwrap_or_else(|| format!("PR #{}", pr.number)),
            body: pr.body.clone(),
            author: pr.user.map(|user| user.login),
            base_branch: pr.base.ref_field,
            head_branch: pr.head.ref_field,
            additions: stats.additions as u64,
            deletions: stats.deletions as u64,
            changed_files: stats.files as u64,
            html_url: pr.html_url,
        };

        Ok(PullRequestDetails {
            context,
            diff,
            files,
        })
    }

    async fn fetch_pull_request(
        &self,
        owner: &str,
        repo: &str,
        number: u64,
    ) -> Result<GitHubPullRequest> {
        let url = format!(
            "{}/repos/{}/{}/pulls/{}",
            self.base_url.trim_end_matches('/'),
            owner,
            repo,
            number
        );
        let response = self
            .request(&url, Some("application/vnd.github+json"))?
            .send()
            .await
            .map_err(|err| Error::Network(err.to_string()))?;
        if !response.status().is_success() {
            return Err(Error::Network(format!(
                "GitHub returned status {} while fetching pull request",
                response.status()
            )));
        }
        response
            .json::<GitHubPullRequest>()
            .await
            .map_err(|err| Error::Network(err.to_string()))
    }

    async fn fetch_diff(&self, owner: &str, repo: &str, number: u64) -> Result<String> {
        let url = format!(
            "{}/repos/{}/{}/pulls/{}",
            self.base_url.trim_end_matches('/'),
            owner,
            repo,
            number
        );
        let response = self
            .request(&url, Some("application/vnd.github.v3.diff"))?
            .send()
            .await
            .map_err(|err| Error::Network(err.to_string()))?;
        if !response.status().is_success() {
            return Err(Error::Network(format!(
                "GitHub returned status {} while downloading diff",
                response.status()
            )));
        }
        response
            .text()
            .await
            .map_err(|err| Error::Network(err.to_string()))
    }

    async fn fetch_files(
        &self,
        owner: &str,
        repo: &str,
        number: u64,
    ) -> Result<Vec<GithubPullFile>> {
        let mut results = Vec::new();
        let mut next_url = Some(format!(
            "{}/repos/{}/{}/pulls/{}/files?per_page=100",
            self.base_url.trim_end_matches('/'),
            owner,
            repo,
            number
        ));

        while let Some(url) = next_url {
            let response = self
                .request(&url, Some("application/vnd.github+json"))?
                .send()
                .await
                .map_err(|err| Error::Network(err.to_string()))?;
            if !response.status().is_success() {
                return Err(Error::Network(format!(
                    "GitHub returned status {} while listing PR files",
                    response.status()
                )));
            }
            let link_header = response.headers().get("link").cloned();
            let page: Vec<GitHubPullFileApi> = response
                .json()
                .await
                .map_err(|err| Error::Network(err.to_string()))?;
            results.extend(page.into_iter().map(GithubPullFile::from));
            next_url = next_link(link_header.as_ref());
        }

        Ok(results)
    }

    fn request(&self, url: &str, accept: Option<&str>) -> Result<reqwest::RequestBuilder> {
        let mut builder = self.client.get(url);
        builder = builder.header(USER_AGENT, USER_AGENT_VALUE);
        if let Some(token) = &self.token {
            builder = builder.header(AUTHORIZATION, format!("token {}", token));
        }
        if let Some(accept) = accept {
            builder = builder.header(ACCEPT, accept);
        }
        Ok(builder)
    }
}

/// Rich pull request details used by automation workflows.
#[derive(Debug, Clone)]
pub struct PullRequestDetails {
    pub context: PullRequestContext,
    pub diff: String,
    pub files: Vec<GithubPullFile>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GithubPullFile {
    pub filename: String,
    pub status: String,
    pub additions: u64,
    pub deletions: u64,
    pub changes: u64,
    pub patch: Option<String>,
}

impl From<GitHubPullFileApi> for GithubPullFile {
    fn from(value: GitHubPullFileApi) -> Self {
        Self {
            filename: value.filename,
            status: value.status,
            additions: value.additions,
            deletions: value.deletions,
            changes: value.changes,
            patch: value.patch,
        }
    }
}

#[derive(Debug, Deserialize)]
struct GitHubPullRequest {
    number: u64,
    title: Option<String>,
    body: Option<String>,
    user: Option<GitHubUser>,
    base: GitRef,
    head: GitRef,
    html_url: Option<String>,
}

#[derive(Debug, Deserialize)]
struct GitHubUser {
    login: String,
}

#[derive(Debug, Deserialize)]
struct GitRef {
    #[serde(rename = "ref")]
    ref_field: String,
}

#[derive(Debug, Deserialize)]
struct GitHubPullFileApi {
    filename: String,
    status: String,
    additions: u64,
    deletions: u64,
    changes: u64,
    #[serde(default)]
    patch: Option<String>,
}

fn next_link(header: Option<&HeaderValue>) -> Option<String> {
    let header = header?.to_str().ok()?;
    for part in header.split(',') {
        let segments: Vec<&str> = part.split(';').collect();
        if segments.len() < 2 {
            continue;
        }
        let url = segments[0]
            .trim()
            .trim_start_matches('<')
            .trim_end_matches('>');
        let rel = segments[1].trim();
        if rel == "rel=\"next\"" {
            return Some(url.to_string());
        }
    }
    None
}
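
// Worked example of the pagination handling above. GitHub sends a Link header
// such as (URLs shortened for illustration):
//
//     <https://api.github.com/repositories/1/pulls/2/files?page=2>; rel="next",
//     <https://api.github.com/repositories/1/pulls/2/files?page=5>; rel="last"
//
// `next_link` yields the rel="next" URL, so `fetch_files` keeps following
// pages until GitHub stops advertising one.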
@@ -1,223 +0,0 @@
use std::collections::VecDeque;

/// Text input buffer with history and cursor management.
#[derive(Debug, Clone)]
pub struct InputBuffer {
    buffer: String,
    cursor: usize,
    history: VecDeque<String>,
    history_index: Option<usize>,
    max_history: usize,
    pub multiline: bool,
    tab_width: u8,
}

impl InputBuffer {
    /// Create a new input buffer
    pub fn new(max_history: usize, multiline: bool, tab_width: u8) -> Self {
        Self {
            buffer: String::new(),
            cursor: 0,
            history: VecDeque::with_capacity(max_history.max(1)),
            history_index: None,
            max_history: max_history.max(1),
            multiline,
            tab_width: tab_width.max(1),
        }
    }

    /// Get current text
    pub fn text(&self) -> &str {
        &self.buffer
    }

    /// Current cursor position
    pub fn cursor(&self) -> usize {
        self.cursor
    }

    /// Replace buffer contents
    pub fn set_text(&mut self, text: impl Into<String>) {
        self.buffer = text.into();
        self.cursor = self.buffer.len();
        self.history_index = None;
    }

    /// Clear buffer and reset cursor
    pub fn clear(&mut self) {
        self.buffer.clear();
        self.cursor = 0;
        self.history_index = None;
    }

    /// Insert a character at the cursor position
    pub fn insert_char(&mut self, ch: char) {
        if ch == '\t' {
            self.insert_tab();
            return;
        }

        self.buffer.insert(self.cursor, ch);
        self.cursor += ch.len_utf8();
    }

    /// Insert text at cursor
    pub fn insert_text(&mut self, text: &str) {
        self.buffer.insert_str(self.cursor, text);
        self.cursor += text.len();
    }

    /// Insert spaces representing a tab
    pub fn insert_tab(&mut self) {
        let spaces = " ".repeat(self.tab_width as usize);
        self.insert_text(&spaces);
    }

    /// Remove the character before the cursor
    pub fn backspace(&mut self) {
        if self.cursor == 0 {
            return;
        }

        let prev_index = prev_char_boundary(&self.buffer, self.cursor);
        self.buffer.drain(prev_index..self.cursor);
        self.cursor = prev_index;
    }

    /// Remove the character at the cursor
    pub fn delete(&mut self) {
        if self.cursor >= self.buffer.len() {
            return;
        }

        let next_index = next_char_boundary(&self.buffer, self.cursor);
        self.buffer.drain(self.cursor..next_index);
    }

    /// Move cursor left by one character (char boundary, not grapheme cluster)
    pub fn move_left(&mut self) {
        if self.cursor == 0 {
            return;
        }
        self.cursor = prev_char_boundary(&self.buffer, self.cursor);
    }

    /// Move cursor right by one character (char boundary, not grapheme cluster)
    pub fn move_right(&mut self) {
        if self.cursor >= self.buffer.len() {
            return;
        }
        self.cursor = next_char_boundary(&self.buffer, self.cursor);
    }

    /// Move cursor to start of the buffer
    pub fn move_home(&mut self) {
        self.cursor = 0;
    }

    /// Move cursor to end of the buffer
    pub fn move_end(&mut self) {
        self.cursor = self.buffer.len();
    }

    /// Push current buffer into history, clearing the buffer afterwards
    pub fn commit_to_history(&mut self) -> String {
        let text = std::mem::take(&mut self.buffer);
        if !text.trim().is_empty() {
            self.push_history_entry(text.clone());
        }
        self.cursor = 0;
        self.history_index = None;
        text
    }

    /// Navigate to previous history entry
    pub fn history_previous(&mut self) {
        if self.history.is_empty() {
            return;
        }

        let new_index = match self.history_index {
            Some(idx) if idx + 1 < self.history.len() => idx + 1,
            None => 0,
            _ => return,
        };

        self.history_index = Some(new_index);
        if let Some(entry) = self.history.get(new_index) {
            self.buffer = entry.clone();
            self.cursor = self.buffer.len();
        }
    }

    /// Navigate to next history entry
    pub fn history_next(&mut self) {
        if self.history.is_empty() {
            return;
        }

        if let Some(idx) = self.history_index {
            if idx > 0 {
                let new_idx = idx - 1;
                self.history_index = Some(new_idx);
                if let Some(entry) = self.history.get(new_idx) {
                    self.buffer = entry.clone();
                    self.cursor = self.buffer.len();
                }
            } else {
                self.history_index = None;
                self.buffer.clear();
                self.cursor = 0;
            }
        } else {
            self.buffer.clear();
            self.cursor = 0;
        }
    }

    /// Push a new entry into the history buffer, enforcing capacity
    pub fn push_history_entry(&mut self, entry: String) {
        if self
            .history
            .front()
            .map(|existing| existing == &entry)
            .unwrap_or(false)
        {
            return;
        }

        self.history.push_front(entry);
        while self.history.len() > self.max_history {
            self.history.pop_back();
        }
    }

    /// Clear saved input history entries.
    pub fn clear_history(&mut self) {
        self.history.clear();
        self.history_index = None;
    }
}

fn prev_char_boundary(buffer: &str, cursor: usize) -> usize {
    buffer[..cursor]
        .char_indices()
        .last()
        .map(|(idx, _)| idx)
        .unwrap_or(0)
}

fn next_char_boundary(buffer: &str, cursor: usize) -> usize {
    if cursor >= buffer.len() {
        return buffer.len();
    }

    let slice = &buffer[cursor..];
    let mut iter = slice.char_indices();
    iter.next();
    if let Some((idx, _)) = iter.next() {
        cursor + idx
    } else {
        buffer.len()
    }
}
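
// History round-trip sketch (capacity, flags, and strings are illustrative):
//
//     let mut input = InputBuffer::new(100, false, 4);
//     input.set_text("first command");
//     let sent = input.commit_to_history(); // returns "first command" and clears the buffer
//     input.history_previous();             // buffer holds "first command" again
//     input.history_next();                 // back to an empty buffer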
@@ -1,123 +0,0 @@
//! Core traits and types for OWLEN LLM client
//!
//! This crate provides the foundational abstractions for building
//! LLM providers, routers, and MCP (Model Context Protocol) adapters.

pub mod agent;
pub mod agent_registry;
pub mod automation;
pub mod config;
pub mod consent;
pub mod conversation;
pub mod credentials;
pub mod encryption;
pub mod facade;
pub mod formatting;
pub mod github;
pub mod input;
pub mod llm;
pub mod mcp;
pub mod mode;
pub mod model;
pub mod oauth;
pub mod provider;
pub mod providers;
pub mod router;
pub mod sandbox;
pub mod session;
pub mod state;
pub mod storage;
pub mod tools;
pub mod types;
pub mod ui;
pub mod usage;
pub mod validation;
pub mod wrap_cursor;

// Re-export theme types from owlen-ui-common
pub use owlen_ui_common::{
    Color, NamedColor, Theme, ThemePalette, built_in_themes, default_themes_dir, get_theme,
    load_all_themes,
};

pub use agent::*;
pub use agent_registry::*;
pub use automation::*;
pub use config::*;
pub use consent::*;
pub use conversation::*;
pub use credentials::*;
pub use encryption::*;
pub use formatting::*;
pub use github::*;
pub use input::*;
pub use oauth::*;
// Export MCP types but exclude test_utils to avoid ambiguity
pub use facade::llm_client::*;
pub use llm::{
    ChatStream, LlmProvider, Provider, ProviderConfig, ProviderRegistry, send_via_stream,
};
pub use mcp::{
    LocalMcpClient, McpServer, McpToolCall, McpToolDescriptor, McpToolResponse, client, factory,
    failover, permission, protocol, remote_client,
};
pub use mode::*;
pub use model::*;
pub use provider::*;
pub use providers::*;
pub use router::*;
pub use sandbox::*;
pub use session::*;
pub use state::*;
pub use tools::*;
pub use usage::*;
pub use validation::*;

/// Result type used throughout the OWLEN ecosystem
pub type Result<T> = std::result::Result<T, Error>;

/// Core error types for OWLEN
#[derive(thiserror::Error, Debug)]
pub enum Error {
    #[error("Provider error: {0}")]
    Provider(#[from] anyhow::Error),

    #[error("Provider failure: {0}")]
    ProviderFailure(provider::ProviderError),

    #[error("Network error: {0}")]
    Network(String),

    #[error("Authentication error: {0}")]
    Auth(String),

    #[error("Configuration error: {0}")]
    Config(String),

    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    #[error("Invalid input: {0}")]
    InvalidInput(String),

    #[error("Operation timed out: {0}")]
    Timeout(String),

    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),

    #[error("Storage error: {0}")]
    Storage(String),

    #[error("Unknown error: {0}")]
    Unknown(String),

    #[error("Not implemented: {0}")]
    NotImplemented(String),

    #[error("Permission denied: {0}")]
    PermissionDenied(String),

    #[error("Agent execution error: {0}")]
    Agent(String),
}
@@ -1,337 +0,0 @@
//! LLM provider abstractions and registry.
//!
//! This module defines the provider trait hierarchy along with helpers that
//! make it easy to register concrete LLM backends and access them through
//! dynamic dispatch when wiring the application together.

use crate::{Error, Result, types::*};
use anyhow::anyhow;
use futures::{Stream, StreamExt};
use serde_json::Value;
use std::any::Any;
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

/// A boxed stream of chat responses produced by a provider.
pub type ChatStream = Pin<Box<dyn Stream<Item = Result<ChatResponse>> + Send>>;

/// Trait implemented by every LLM backend Owlen can speak to.
///
/// Providers expose both one-shot and streaming prompt APIs. Concrete
/// implementations typically live in `crate::providers`.
pub trait LlmProvider: Send + Sync + 'static + Any + Sized {
    /// Stream type returned by [`Self::stream_prompt`].
    type Stream: Stream<Item = Result<ChatResponse>> + Send + 'static;

    type ListModelsFuture<'a>: Future<Output = Result<Vec<ModelInfo>>> + Send
    where
        Self: 'a;

    type SendPromptFuture<'a>: Future<Output = Result<ChatResponse>> + Send
    where
        Self: 'a;

    type StreamPromptFuture<'a>: Future<Output = Result<Self::Stream>> + Send
    where
        Self: 'a;

    type HealthCheckFuture<'a>: Future<Output = Result<()>> + Send
    where
        Self: 'a;

    /// Human-readable provider identifier.
    fn name(&self) -> &str;

    /// Return metadata on all models exposed by this provider.
    fn list_models(&self) -> Self::ListModelsFuture<'_>;

    /// Issue a prompt and wait for the provider to return the full response.
    fn send_prompt(&self, request: ChatRequest) -> Self::SendPromptFuture<'_>;

    /// Issue a prompt and receive responses incrementally as a stream.
    fn stream_prompt(&self, request: ChatRequest) -> Self::StreamPromptFuture<'_>;

    /// Perform a lightweight health check.
    fn health_check(&self) -> Self::HealthCheckFuture<'_>;

    /// Provider-specific configuration schema (optional).
    fn config_schema(&self) -> serde_json::Value {
        serde_json::json!({})
    }

    /// Access the provider as an `Any` for downcasting.
    fn as_any(&self) -> &(dyn Any + Send + Sync) {
        self
    }
}

/// Helper that requests a streamed generation and yields the first chunk as a
/// regular response. This is handy for providers that only implement the
/// streaming API.
pub async fn send_via_stream<'a, P>(provider: &'a P, request: ChatRequest) -> Result<ChatResponse>
where
    P: LlmProvider + 'a,
{
    let stream = provider.stream_prompt(request).await?;
    let mut boxed: ChatStream = Box::pin(stream);
    match boxed.next().await {
        Some(Ok(response)) => Ok(response),
        Some(Err(err)) => Err(err),
        None => Err(Error::Provider(anyhow!(
            "Empty chat stream from provider {}",
            provider.name()
        ))),
    }
}

/// Object-safe wrapper around [`LlmProvider`] for dynamic dispatch scenarios.
#[async_trait::async_trait]
pub trait Provider: Send + Sync {
    fn name(&self) -> &str;

    async fn list_models(&self) -> Result<Vec<ModelInfo>>;

    async fn send_prompt(&self, request: ChatRequest) -> Result<ChatResponse>;

    async fn stream_prompt(&self, request: ChatRequest) -> Result<ChatStream>;

    async fn health_check(&self) -> Result<()>;

    fn config_schema(&self) -> serde_json::Value {
        serde_json::json!({})
    }

    fn as_any(&self) -> &(dyn Any + Send + Sync);
}

#[async_trait::async_trait]
impl<T> Provider for T
where
    T: LlmProvider,
{
    fn name(&self) -> &str {
        LlmProvider::name(self)
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        LlmProvider::list_models(self).await
    }

    async fn send_prompt(&self, request: ChatRequest) -> Result<ChatResponse> {
        LlmProvider::send_prompt(self, request).await
    }

    async fn stream_prompt(&self, request: ChatRequest) -> Result<ChatStream> {
        let stream = LlmProvider::stream_prompt(self, request).await?;
        Ok(Box::pin(stream))
    }

    async fn health_check(&self) -> Result<()> {
        LlmProvider::health_check(self).await
    }

    fn config_schema(&self) -> serde_json::Value {
        LlmProvider::config_schema(self)
    }

    fn as_any(&self) -> &(dyn Any + Send + Sync) {
        LlmProvider::as_any(self)
    }
}

/// Runtime configuration for a provider instance.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ProviderConfig {
    /// Whether this provider should be activated.
    #[serde(default = "ProviderConfig::default_enabled")]
    pub enabled: bool,
    /// Provider type identifier used to resolve implementations.
    #[serde(default)]
    pub provider_type: String,
    /// Base URL for API calls.
    #[serde(default)]
    pub base_url: Option<String>,
    /// API key or token material.
    #[serde(default)]
    pub api_key: Option<String>,
    /// Environment variable holding the API key.
    #[serde(default)]
    pub api_key_env: Option<String>,
    /// Additional provider-specific configuration.
    #[serde(flatten)]
    pub extra: HashMap<String, Value>,
}

impl ProviderConfig {
    const fn default_enabled() -> bool {
        true
    }

    /// Merge the current configuration with overrides from `other`.
    pub fn merge_from(&mut self, mut other: ProviderConfig) {
        self.enabled = other.enabled;

        if !other.provider_type.is_empty() {
            self.provider_type = other.provider_type;
        }

        if let Some(base_url) = other.base_url.take() {
            self.base_url = Some(base_url);
        }

        if let Some(api_key) = other.api_key.take() {
            self.api_key = Some(api_key);
        }

        if let Some(api_key_env) = other.api_key_env.take() {
            self.api_key_env = Some(api_key_env);
        }

        if !other.extra.is_empty() {
            self.extra.extend(other.extra);
        }
    }
}

/// Static registry of providers available to the application.
pub struct ProviderRegistry {
    providers: HashMap<String, Arc<dyn Provider>>,
}

impl ProviderRegistry {
    pub fn new() -> Self {
        Self {
            providers: HashMap::new(),
        }
    }

    pub fn register<P: LlmProvider + 'static>(&mut self, provider: P) {
        self.register_arc(Arc::new(provider));
    }

    pub fn register_arc(&mut self, provider: Arc<dyn Provider>) {
        let name = provider.name().to_string();
        self.providers.insert(name, provider);
    }

    pub fn get(&self, name: &str) -> Option<Arc<dyn Provider>> {
        self.providers.get(name).cloned()
    }

    pub fn list_providers(&self) -> Vec<String> {
        self.providers.keys().cloned().collect()
    }

    pub async fn list_all_models(&self) -> Result<Vec<ModelInfo>> {
        let mut all_models = Vec::new();

        for provider in self.providers.values() {
            match provider.list_models().await {
                Ok(mut models) => all_models.append(&mut models),
                Err(_) => {
                    // Ignore failing providers and continue.
                }
            }
        }

        Ok(all_models)
    }
}

impl Default for ProviderRegistry {
    fn default() -> Self {
        Self::new()
    }
}

/// Test utilities for constructing mock providers.
#[cfg(test)]
pub mod test_utils {
    use super::*;
    use futures::stream;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Simple provider stub that always returns the same response.
    pub struct MockProvider {
        name: String,
        response: ChatResponse,
        call_count: AtomicUsize,
    }

    impl MockProvider {
        pub fn new(name: impl Into<String>, response: ChatResponse) -> Self {
            Self {
                name: name.into(),
                response,
                call_count: AtomicUsize::new(0),
            }
        }

        pub fn call_count(&self) -> usize {
            self.call_count.load(Ordering::Relaxed)
        }
    }

    impl Default for MockProvider {
        fn default() -> Self {
            Self::new(
                "mock-provider",
                ChatResponse {
                    message: Message::assistant("mock response".to_string()),
                    usage: None,
                    is_streaming: false,
                    is_final: true,
                },
            )
        }
    }

    impl LlmProvider for MockProvider {
        type Stream = stream::Iter<std::vec::IntoIter<Result<ChatResponse>>>;

        type ListModelsFuture<'a>
            = futures::future::Ready<Result<Vec<ModelInfo>>>
        where
            Self: 'a;

        type SendPromptFuture<'a>
            = futures::future::Ready<Result<ChatResponse>>
        where
            Self: 'a;

        type StreamPromptFuture<'a>
            = futures::future::Ready<Result<Self::Stream>>
        where
            Self: 'a;

        type HealthCheckFuture<'a>
            = futures::future::Ready<Result<()>>
        where
            Self: 'a;

        fn name(&self) -> &str {
            &self.name
        }

        fn list_models(&self) -> Self::ListModelsFuture<'_> {
            futures::future::ready(Ok(vec![]))
        }

        fn send_prompt(&self, _request: ChatRequest) -> Self::SendPromptFuture<'_> {
            self.call_count.fetch_add(1, Ordering::Relaxed);
            futures::future::ready(Ok(self.response.clone()))
        }

        fn stream_prompt(&self, _request: ChatRequest) -> Self::StreamPromptFuture<'_> {
            self.call_count.fetch_add(1, Ordering::Relaxed);
            let response = self.response.clone();
            futures::future::ready(Ok(stream::iter(vec![Ok(response)])))
        }

        fn health_check(&self) -> Self::HealthCheckFuture<'_> {
            futures::future::ready(Ok(()))
        }
    }
}
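
// Registry wiring sketch: `register` erases the concrete provider behind the
// object-safe `Provider` wrapper. "mock-provider" matches the test stub's
// default name; a real binary would register its own backends instead:
//
//     let mut registry = ProviderRegistry::new();
//     registry.register(test_utils::MockProvider::default()); // cfg(test) only
//     let provider = registry.get("mock-provider").expect("registered above");
//     let models = provider.list_models().await?;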
@@ -1,188 +0,0 @@
use crate::Result;
use crate::mode::Mode;
use crate::tools::registry::ToolRegistry;
use crate::validation::SchemaValidator;
use async_trait::async_trait;
pub use client::McpClient;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;

pub mod client;
pub mod factory;
pub mod failover;
pub mod permission;
pub mod presets;
pub mod protocol;
pub mod remote_client;

/// Descriptor for a tool exposed over MCP
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolDescriptor {
    pub name: String,
    pub description: String,
    pub input_schema: Value,
    pub requires_network: bool,
    pub requires_filesystem: Vec<String>,
}

/// Invocation payload for a tool call
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolCall {
    pub name: String,
    pub arguments: Value,
}

/// Result returned by a tool invocation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct McpToolResponse {
    pub name: String,
    pub success: bool,
    pub output: Value,
    pub metadata: HashMap<String, String>,
    pub duration_ms: u128,
}

/// Thin MCP server facade over the tool registry
pub struct McpServer {
    registry: Arc<ToolRegistry>,
    validator: Arc<SchemaValidator>,
    mode: Arc<tokio::sync::RwLock<Mode>>,
}

impl McpServer {
    pub fn new(registry: Arc<ToolRegistry>, validator: Arc<SchemaValidator>) -> Self {
        Self {
            registry,
            validator,
            mode: Arc::new(tokio::sync::RwLock::new(Mode::default())),
        }
    }

    /// Set the current operating mode
    pub async fn set_mode(&self, mode: Mode) {
        *self.mode.write().await = mode;
    }

    /// Get the current operating mode
    pub async fn get_mode(&self) -> Mode {
        *self.mode.read().await
    }

    /// Enumerate the registered tools as MCP descriptors
    pub async fn list_tools(&self) -> Vec<McpToolDescriptor> {
        let mode = self.get_mode().await;
        let available_tools = self.registry.available_tools(mode).await;

        self.registry
            .all()
            .into_iter()
            .filter(|tool| available_tools.contains(&tool.name().to_string()))
            .map(|tool| McpToolDescriptor {
                name: tool.name().to_string(),
                description: tool.description().to_string(),
                input_schema: tool.schema(),
                requires_network: tool.requires_network(),
                requires_filesystem: tool.requires_filesystem(),
            })
            .collect()
    }

    /// Execute a tool call after validating inputs against the registered schema
    pub async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
        self.validator.validate(&call.name, &call.arguments)?;
        let mode = self.get_mode().await;
        let result = self
            .registry
            .execute(&call.name, call.arguments, mode)
            .await?;
        Ok(McpToolResponse {
            name: call.name,
            success: result.success,
            output: result.output,
            metadata: result.metadata,
            duration_ms: duration_to_millis(result.duration),
        })
    }
}

fn duration_to_millis(duration: Duration) -> u128 {
    duration.as_secs() as u128 * 1_000 + u128::from(duration.subsec_millis())
}
|
||||
pub struct LocalMcpClient {
|
||||
server: McpServer,
|
||||
}
|
||||
|
||||
impl LocalMcpClient {
|
||||
pub fn new(registry: Arc<ToolRegistry>, validator: Arc<SchemaValidator>) -> Self {
|
||||
Self {
|
||||
server: McpServer::new(registry, validator),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the current operating mode
|
||||
pub async fn set_mode(&self, mode: Mode) {
|
||||
self.server.set_mode(mode).await;
|
||||
}
|
||||
|
||||
/// Get the current operating mode
|
||||
pub async fn get_mode(&self) -> Mode {
|
||||
self.server.get_mode().await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl McpClient for LocalMcpClient {
|
||||
async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>> {
|
||||
Ok(self.server.list_tools().await)
|
||||
}
|
||||
|
||||
async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
|
||||
self.server.call_tool(call).await
|
||||
}
|
||||
|
||||
async fn set_mode(&self, mode: Mode) -> Result<()> {
|
||||
self.server.set_mode(mode).await;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub mod test_utils {
|
||||
use super::*;
|
||||
|
||||
/// Mock MCP client for testing
|
||||
#[derive(Default)]
|
||||
pub struct MockMcpClient;
|
||||
|
||||
#[async_trait]
|
||||
impl McpClient for MockMcpClient {
|
||||
async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>> {
|
||||
Ok(vec![McpToolDescriptor {
|
||||
name: "mock_tool".to_string(),
|
||||
description: "A mock tool for testing".to_string(),
|
||||
input_schema: serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"}
|
||||
}
|
||||
}),
|
||||
requires_network: false,
|
||||
requires_filesystem: vec![],
|
||||
}])
|
||||
}
|
||||
|
||||
async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
|
||||
Ok(McpToolResponse {
|
||||
name: call.name,
|
||||
success: true,
|
||||
output: serde_json::json!({"result": "mock result"}),
|
||||
metadata: HashMap::new(),
|
||||
duration_ms: 10,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
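As a sanity check on duration_to_millis above, a hypothetical test (not part of the original file) showing the seconds-plus-subsecond arithmetic:

#[test]
fn duration_to_millis_combines_secs_and_subsec() {
    // 2s + 345ms -> 2_345ms; sub-millisecond precision is dropped.
    assert_eq!(duration_to_millis(Duration::from_millis(2_345)), 2_345);
    assert_eq!(duration_to_millis(Duration::from_micros(1_500)), 1);
}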
@@ -1,21 +0,0 @@
use super::{McpToolCall, McpToolDescriptor, McpToolResponse};
use crate::{Result, mode::Mode};
use async_trait::async_trait;

/// Trait for a client that can interact with an MCP server
#[async_trait]
pub trait McpClient: Send + Sync {
    /// List the tools available on the server
    async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>>;

    /// Call a tool on the server
    async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse>;

    /// Update the server with the active operating mode.
    async fn set_mode(&self, _mode: Mode) -> Result<()> {
        Ok(())
    }
}

// Re-export the concrete implementation that supports stdio and HTTP transports.
pub use super::remote_client::RemoteMcpClient;
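A sketch of the minimal surface an implementor must provide; NullMcpClient is hypothetical, and set_mode is inherited from the default body above:

struct NullMcpClient;

#[async_trait]
impl McpClient for NullMcpClient {
    async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>> {
        // No tools to offer.
        Ok(Vec::new())
    }

    async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
        // Echo an empty success so callers can exercise the happy path.
        Ok(McpToolResponse {
            name: call.name,
            success: true,
            output: serde_json::Value::Null,
            metadata: std::collections::HashMap::new(),
            duration_ms: 0,
        })
    }
}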
@@ -1,194 +0,0 @@
/// MCP Client Factory
///
/// Provides a unified interface for creating MCP clients based on configuration.
/// Supports switching between local (in-process) and remote (STDIO) execution modes.
use super::client::McpClient;
use super::{
    LocalMcpClient,
    remote_client::{McpRuntimeSecrets, RemoteMcpClient},
};
use crate::config::{Config, McpMode};
use crate::tools::registry::ToolRegistry;
use crate::validation::SchemaValidator;
use crate::{Error, Result};
use log::{info, warn};
use std::sync::Arc;

/// Factory for creating MCP clients based on configuration
pub struct McpClientFactory {
    config: Arc<Config>,
    registry: Arc<ToolRegistry>,
    validator: Arc<SchemaValidator>,
}

impl McpClientFactory {
    pub fn new(
        config: Arc<Config>,
        registry: Arc<ToolRegistry>,
        validator: Arc<SchemaValidator>,
    ) -> Self {
        Self {
            config,
            registry,
            validator,
        }
    }

    /// Create an MCP client based on the current configuration.
    pub async fn create(&self) -> Result<Box<dyn McpClient>> {
        self.create_with_secrets(None).await
    }

    /// Create an MCP client using optional runtime secrets (OAuth tokens, env overrides).
    pub async fn create_with_secrets(
        &self,
        runtime: Option<McpRuntimeSecrets>,
    ) -> Result<Box<dyn McpClient>> {
        match self.config.mcp.mode {
            McpMode::Disabled => Err(Error::Config(
                "MCP mode is set to 'disabled'; tooling cannot function in this configuration."
                    .to_string(),
            )),
            McpMode::LocalOnly | McpMode::Legacy => {
                if matches!(self.config.mcp.mode, McpMode::Legacy) {
                    warn!("Using deprecated MCP legacy mode; consider switching to 'local_only'.");
                }
                Ok(Box::new(LocalMcpClient::new(
                    self.registry.clone(),
                    self.validator.clone(),
                )))
            }
            McpMode::RemoteOnly => {
                let server_cfg = self.config.effective_mcp_servers().first().ok_or_else(|| {
                    Error::Config(
                        "MCP mode 'remote_only' requires at least one entry in [[mcp_servers]]"
                            .to_string(),
                    )
                })?;

                RemoteMcpClient::new_with_runtime(server_cfg, runtime)
                    .await
                    .map(|client| Box::new(client) as Box<dyn McpClient>)
                    .map_err(|e| {
                        Error::Config(format!(
                            "Failed to start remote MCP client '{}': {e}",
                            server_cfg.name
                        ))
                    })
            }
            McpMode::RemotePreferred => {
                if let Some(server_cfg) = self.config.effective_mcp_servers().first() {
                    match RemoteMcpClient::new_with_runtime(server_cfg, runtime.clone()).await {
                        Ok(client) => {
                            info!(
                                "Connected to remote MCP server '{}' via {} transport.",
                                server_cfg.name, server_cfg.transport
                            );
                            Ok(Box::new(client) as Box<dyn McpClient>)
                        }
                        Err(e) if self.config.mcp.allow_fallback => {
                            warn!(
                                "Failed to start remote MCP client '{}': {}. Falling back to local tooling.",
                                server_cfg.name, e
                            );
                            Ok(Box::new(LocalMcpClient::new(
                                self.registry.clone(),
                                self.validator.clone(),
                            )))
                        }
                        Err(e) => Err(Error::Config(format!(
                            "Failed to start remote MCP client '{}': {e}. To allow fallback, set [mcp].allow_fallback = true.",
                            server_cfg.name
                        ))),
                    }
                } else {
                    warn!("No MCP servers configured; using local MCP tooling.");
                    Ok(Box::new(LocalMcpClient::new(
                        self.registry.clone(),
                        self.validator.clone(),
                    )))
                }
            }
        }
    }

    /// Check if remote MCP mode is available
    pub async fn is_remote_available() -> bool {
        RemoteMcpClient::new().await.is_ok()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Error;
    use crate::config::McpServerConfig;

    fn build_factory(config: Config) -> McpClientFactory {
        let ui = Arc::new(crate::ui::NoOpUiController);
        let registry = Arc::new(ToolRegistry::new(
            Arc::new(tokio::sync::Mutex::new(config.clone())),
            ui,
        ));
        let validator = Arc::new(SchemaValidator::new());

        McpClientFactory::new(Arc::new(config), registry, validator)
    }

    #[tokio::test]
    async fn test_factory_creates_local_client_when_no_servers_configured() {
        let mut config = Config::default();
        config.refresh_mcp_servers(None).unwrap();

        let factory = build_factory(config);

        // Should create without error and fall back to local client
        let result = factory.create().await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_remote_only_without_servers_errors() {
        let mut config = Config::default();
        config.mcp.mode = McpMode::RemoteOnly;
        config.mcp_servers.clear();
        config.refresh_mcp_servers(None).unwrap();

        let factory = build_factory(config);
        let result = factory.create().await;
        assert!(matches!(result, Err(Error::Config(_))));
    }

    #[tokio::test]
    async fn test_remote_preferred_without_fallback_propagates_remote_error() {
        let mut config = Config::default();
        config.mcp.mode = McpMode::RemotePreferred;
        config.mcp.allow_fallback = false;
        config.mcp_servers = vec![McpServerConfig {
            name: "invalid".to_string(),
            command: "nonexistent-mcp-server-binary".to_string(),
            args: Vec::new(),
            transport: "stdio".to_string(),
            env: std::collections::HashMap::new(),
            oauth: None,
            rpc_timeout_secs: None,
        }];
        config.refresh_mcp_servers(None).unwrap();

        let factory = build_factory(config);
        let result = factory.create().await;
        assert!(
            matches!(result, Err(Error::Config(message)) if message.contains("Failed to start remote MCP client"))
        );
    }

    #[tokio::test]
    async fn test_legacy_mode_uses_local_client() {
        let mut config = Config::default();
        config.mcp.mode = McpMode::Legacy;

        let factory = build_factory(config);
        let result = factory.create().await;
        assert!(result.is_ok());
    }
}
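A usage sketch mirroring the build_factory test helper above; the wiring (registry, validator, NoOpUiController) follows the tests and is otherwise an assumption:

async fn demo_create_client() -> Result<()> {
    let config = Config::default();
    let ui = Arc::new(crate::ui::NoOpUiController);
    let registry = Arc::new(ToolRegistry::new(
        Arc::new(tokio::sync::Mutex::new(config.clone())),
        ui,
    ));
    let validator = Arc::new(SchemaValidator::new());

    let factory = McpClientFactory::new(Arc::new(config), registry, validator);
    // Default config has no remote servers, so this resolves to a local client.
    let client = factory.create().await?;
    let _tools = client.list_tools().await?;
    Ok(())
}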
@@ -1,324 +0,0 @@
//! Failover and redundancy support for MCP clients
//!
//! Provides automatic failover between multiple MCP servers with:
//! - Health checking
//! - Priority-based selection
//! - Automatic retry with exponential backoff
//! - Circuit breaker pattern

use super::{McpClient, McpToolCall, McpToolDescriptor, McpToolResponse};
use crate::{Error, Result};
use async_trait::async_trait;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;

/// Server health status
#[derive(Debug, Clone, PartialEq)]
pub enum ServerHealth {
    /// Server is healthy and available
    Healthy,
    /// Server is experiencing issues but may recover
    Degraded { since: Instant },
    /// Server is down
    Down { since: Instant },
}

/// Server configuration with priority
#[derive(Clone)]
pub struct ServerEntry {
    /// Name for logging
    pub name: String,
    /// MCP client instance
    pub client: Arc<dyn McpClient>,
    /// Priority (lower = higher priority)
    pub priority: u32,
    /// Health status
    health: Arc<RwLock<ServerHealth>>,
    /// Last health check time
    last_check: Arc<RwLock<Option<Instant>>>,
}

impl ServerEntry {
    pub fn new(name: String, client: Arc<dyn McpClient>, priority: u32) -> Self {
        Self {
            name,
            client,
            priority,
            health: Arc::new(RwLock::new(ServerHealth::Healthy)),
            last_check: Arc::new(RwLock::new(None)),
        }
    }

    /// Check if server is available
    pub async fn is_available(&self) -> bool {
        let health = self.health.read().await;
        matches!(*health, ServerHealth::Healthy)
    }

    /// Mark server as healthy
    pub async fn mark_healthy(&self) {
        let mut health = self.health.write().await;
        *health = ServerHealth::Healthy;
        let mut last_check = self.last_check.write().await;
        *last_check = Some(Instant::now());
    }

    /// Mark server as down
    pub async fn mark_down(&self) {
        let mut health = self.health.write().await;
        *health = ServerHealth::Down {
            since: Instant::now(),
        };
    }

    /// Mark server as degraded
    pub async fn mark_degraded(&self) {
        let mut health = self.health.write().await;
        if matches!(*health, ServerHealth::Healthy) {
            *health = ServerHealth::Degraded {
                since: Instant::now(),
            };
        }
    }

    /// Get current health status
    pub async fn get_health(&self) -> ServerHealth {
        self.health.read().await.clone()
    }
}

/// Failover configuration
#[derive(Debug, Clone)]
pub struct FailoverConfig {
    /// Maximum number of retry attempts
    pub max_retries: usize,
    /// Base retry delay (will be exponentially increased)
    pub base_retry_delay: Duration,
    /// Health check interval
    pub health_check_interval: Duration,
    /// Timeout for health checks
    pub health_check_timeout: Duration,
    /// Circuit breaker threshold (failures before opening circuit)
    pub circuit_breaker_threshold: usize,
}

impl Default for FailoverConfig {
    fn default() -> Self {
        Self {
            max_retries: 3,
            base_retry_delay: Duration::from_millis(100),
            health_check_interval: Duration::from_secs(30),
            health_check_timeout: Duration::from_secs(5),
            circuit_breaker_threshold: 5,
        }
    }
}

/// MCP client with failover support
pub struct FailoverMcpClient {
    servers: Arc<RwLock<Vec<ServerEntry>>>,
    config: FailoverConfig,
    consecutive_failures: Arc<RwLock<usize>>,
}

impl FailoverMcpClient {
    /// Create a new failover client with multiple servers
    pub fn new(servers: Vec<ServerEntry>, config: FailoverConfig) -> Self {
        // Sort servers by priority
        let mut sorted_servers = servers;
        sorted_servers.sort_by_key(|s| s.priority);

        Self {
            servers: Arc::new(RwLock::new(sorted_servers)),
            config,
            consecutive_failures: Arc::new(RwLock::new(0)),
        }
    }

    /// Create with default configuration
    pub fn with_servers(servers: Vec<ServerEntry>) -> Self {
        Self::new(servers, FailoverConfig::default())
    }

    /// Get the first available server
    async fn get_available_server(&self) -> Option<ServerEntry> {
        let servers = self.servers.read().await;
        for server in servers.iter() {
            if server.is_available().await {
                return Some(server.clone());
            }
        }
        None
    }

    /// Execute an operation with automatic failover
    async fn with_failover<F, T>(&self, operation: F) -> Result<T>
    where
        F: Fn(Arc<dyn McpClient>) -> futures::future::BoxFuture<'static, Result<T>>,
        T: Send + 'static,
    {
        let mut attempt = 0;
        let mut last_error = None;

        while attempt < self.config.max_retries {
            // Get available server
            let server = match self.get_available_server().await {
                Some(s) => s,
                None => {
                    // No healthy servers, try all servers anyway
                    let servers = self.servers.read().await;
                    if let Some(first) = servers.first() {
                        first.clone()
                    } else {
                        return Err(Error::Network("No servers configured".to_string()));
                    }
                }
            };

            // Execute operation
            match operation(server.client.clone()).await {
                Ok(result) => {
                    server.mark_healthy().await;
                    let mut failures = self.consecutive_failures.write().await;
                    *failures = 0;
                    return Ok(result);
                }
                Err(e) => {
                    log::warn!("Server '{}' failed: {}", server.name, e);
                    server.mark_degraded().await;
                    last_error = Some(e);

                    let mut failures = self.consecutive_failures.write().await;
                    *failures += 1;

                    if *failures >= self.config.circuit_breaker_threshold {
                        server.mark_down().await;
                    }
                }
            }

            // Exponential backoff
            if attempt < self.config.max_retries - 1 {
                let delay = self.config.base_retry_delay * 2_u32.pow(attempt as u32);
                tokio::time::sleep(delay).await;
            }

            attempt += 1;
        }

        Err(last_error.unwrap_or_else(|| Error::Network("All servers failed".to_string())))
    }

    /// Perform health check on all servers
    pub async fn health_check_all(&self) {
        let servers = self.servers.read().await;
        for server in servers.iter() {
            let client = server.client.clone();
            let server_clone = server.clone();

            tokio::spawn(async move {
                match tokio::time::timeout(
                    Duration::from_secs(5),
                    // Use a simple list_tools call as health check
                    async { client.list_tools().await },
                )
                .await
                {
                    Ok(Ok(_)) => server_clone.mark_healthy().await,
                    Ok(Err(e)) => {
                        log::warn!("Health check failed for '{}': {}", server_clone.name, e);
                        server_clone.mark_down().await;
                    }
                    Err(_) => {
                        log::warn!("Health check timeout for '{}'", server_clone.name);
                        server_clone.mark_down().await;
                    }
                }
            });
        }
    }

    /// Start background health checking
    pub fn start_health_checks(&self) -> tokio::task::JoinHandle<()> {
        let client = self.clone_ref();
        let interval = self.config.health_check_interval;

        tokio::spawn(async move {
            let mut interval_timer = tokio::time::interval(interval);
            loop {
                interval_timer.tick().await;
                client.health_check_all().await;
            }
        })
    }

    /// Clone the client (returns new handle to same underlying data)
    fn clone_ref(&self) -> Self {
        Self {
            servers: self.servers.clone(),
            config: self.config.clone(),
            consecutive_failures: self.consecutive_failures.clone(),
        }
    }

    /// Get status of all servers
    pub async fn get_server_status(&self) -> Vec<(String, ServerHealth)> {
        let servers = self.servers.read().await;
        let mut status = Vec::new();
        for server in servers.iter() {
            status.push((server.name.clone(), server.get_health().await));
        }
        status
    }
}

#[async_trait]
impl McpClient for FailoverMcpClient {
    async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>> {
        self.with_failover(|client| Box::pin(async move { client.list_tools().await }))
            .await
    }

    async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
        self.with_failover(|client| {
            let call_clone = call.clone();
            Box::pin(async move { client.call_tool(call_clone).await })
        })
        .await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_server_entry_health() {
        use crate::mcp::remote_client::RemoteMcpClient;

        // This would need a mock client in practice
        // Just demonstrating the API
        let config = crate::config::McpServerConfig {
            name: "test".to_string(),
            command: "test".to_string(),
            args: vec![],
            transport: "http".to_string(),
            env: std::collections::HashMap::new(),
            oauth: None,
            rpc_timeout_secs: None,
        };

        if let Ok(client) = RemoteMcpClient::new_with_config(&config).await {
            let entry = ServerEntry::new("test".to_string(), Arc::new(client), 1);

            assert!(entry.is_available().await);

            entry.mark_down().await;
            assert!(!entry.is_available().await);

            entry.mark_healthy().await;
            assert!(entry.is_available().await);
        }
    }
}
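The retry loop above sleeps base_retry_delay * 2^attempt between attempts; with the defaults (100ms base, 3 retries) that is 100ms, then 200ms. A small sketch of that schedule (hypothetical helper, not part of the original file):

fn backoff_schedule(config: &FailoverConfig) -> Vec<Duration> {
    // One delay per retry gap; the final failed attempt returns the error
    // without sleeping, matching with_failover above.
    (0..config.max_retries.saturating_sub(1))
        .map(|attempt| config.base_retry_delay * 2_u32.pow(attempt as u32))
        .collect()
}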
@@ -1,229 +0,0 @@
/// Permission and Safety Layer for MCP
///
/// This module provides runtime enforcement of security policies for tool execution.
/// It wraps MCP clients to filter/whitelist tool calls, log invocations, and prompt for consent.
use super::client::McpClient;
use super::{McpToolCall, McpToolDescriptor, McpToolResponse};
use crate::tools::{WEB_SEARCH_TOOL_NAME, tool_name_matches};
use crate::{Error, Result};
use crate::{config::Config, mode::Mode};
use async_trait::async_trait;
use std::collections::HashSet;
use std::sync::Arc;

/// Callback for requesting user consent for dangerous operations
pub type ConsentCallback = Arc<dyn Fn(&str, &McpToolCall) -> bool + Send + Sync>;

/// Callback for logging tool invocations
pub type LogCallback = Arc<dyn Fn(&str, &McpToolCall, &Result<McpToolResponse>) + Send + Sync>;

/// Permission-enforcing wrapper around an MCP client
pub struct PermissionLayer {
    inner: Box<dyn McpClient>,
    config: Arc<Config>,
    consent_callback: Option<ConsentCallback>,
    log_callback: Option<LogCallback>,
    allowed_tools: HashSet<String>,
}

impl PermissionLayer {
    /// Create a new permission layer wrapping the given client
    pub fn new(inner: Box<dyn McpClient>, config: Arc<Config>) -> Self {
        let allowed_tools = config.security.allowed_tools.iter().cloned().collect();

        Self {
            inner,
            config,
            consent_callback: None,
            log_callback: None,
            allowed_tools,
        }
    }

    /// Set a callback for requesting user consent
    pub fn with_consent_callback(mut self, callback: ConsentCallback) -> Self {
        self.consent_callback = Some(callback);
        self
    }

    /// Set a callback for logging tool invocations
    pub fn with_log_callback(mut self, callback: LogCallback) -> Self {
        self.log_callback = Some(callback);
        self
    }

    /// Check if a tool requires dangerous filesystem operations
    fn requires_dangerous_filesystem(&self, tool_name: &str) -> bool {
        matches!(
            tool_name,
            "resources_write" | "resources_delete" | "file_write" | "file_delete"
        )
    }

    /// Check if a tool is allowed by security policy
    fn is_tool_allowed(&self, tool_descriptor: &McpToolDescriptor) -> bool {
        // Check if tool requires filesystem access
        for fs_perm in &tool_descriptor.requires_filesystem {
            if !self.allowed_tools.contains(fs_perm) {
                return false;
            }
        }

        // Check if tool requires network access
        if tool_descriptor.requires_network
            && !self
                .allowed_tools
                .iter()
                .any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME))
        {
            return false;
        }

        true
    }

    /// Request user consent for a tool call
    fn request_consent(&self, tool_name: &str, call: &McpToolCall) -> bool {
        if let Some(ref callback) = self.consent_callback {
            callback(tool_name, call)
        } else {
            // If no callback is set, deny dangerous operations by default
            !self.requires_dangerous_filesystem(tool_name)
        }
    }

    /// Log a tool invocation
    fn log_invocation(
        &self,
        tool_name: &str,
        call: &McpToolCall,
        result: &Result<McpToolResponse>,
    ) {
        if let Some(ref callback) = self.log_callback {
            callback(tool_name, call, result);
        } else {
            // Default logging to stderr
            match result {
                Ok(resp) => {
                    eprintln!(
                        "[MCP] Tool '{}' executed successfully ({}ms)",
                        tool_name, resp.duration_ms
                    );
                }
                Err(e) => {
                    eprintln!("[MCP] Tool '{}' failed: {}", tool_name, e);
                }
            }
        }
    }
}

#[async_trait]
impl McpClient for PermissionLayer {
    async fn list_tools(&self) -> Result<Vec<McpToolDescriptor>> {
        let tools = self.inner.list_tools().await?;
        // Filter tools based on security policy
        Ok(tools
            .into_iter()
            .filter(|tool| self.is_tool_allowed(tool))
            .collect())
    }

    async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
        // Check if tool requires consent
        if self.requires_dangerous_filesystem(&call.name)
            && self.config.privacy.require_consent_per_session
            && !self.request_consent(&call.name, &call)
        {
            let result = Err(Error::PermissionDenied(format!(
                "User denied consent for tool '{}'",
                call.name
            )));
            self.log_invocation(&call.name, &call, &result);
            return result;
        }

        // Execute the tool call
        let result = self.inner.call_tool(call.clone()).await;

        // Log the invocation
        self.log_invocation(&call.name, &call, &result);

        result
    }

    async fn set_mode(&self, mode: Mode) -> Result<()> {
        self.inner.set_mode(mode).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::mcp::LocalMcpClient;
    use crate::tools::WEB_SEARCH_TOOL_NAME;
    use crate::tools::registry::ToolRegistry;
    use crate::ui::NoOpUiController;
    use crate::validation::SchemaValidator;
    use std::sync::atomic::{AtomicBool, Ordering};

    #[tokio::test]
    async fn test_permission_layer_filters_dangerous_tools() {
        let config = Arc::new(Config::default());
        let ui = Arc::new(NoOpUiController);
        let registry = Arc::new(ToolRegistry::new(
            Arc::new(tokio::sync::Mutex::new((*config).clone())),
            ui,
        ));
        let validator = Arc::new(SchemaValidator::new());
        let client = Box::new(LocalMcpClient::new(registry, validator));

        let mut config_mut = (*config).clone();
        // Disallow file operations
        config_mut.security.allowed_tools = vec![WEB_SEARCH_TOOL_NAME.to_string()];

        let permission_layer = PermissionLayer::new(client, Arc::new(config_mut));

        let tools = permission_layer.list_tools().await.unwrap();

        // Should not include file_write or file_delete tools
        assert!(!tools.iter().any(|t| t.name.contains("write")));
        assert!(!tools.iter().any(|t| t.name.contains("delete")));
    }

    #[tokio::test]
    async fn test_consent_callback_is_invoked() {
        let config = Arc::new(Config::default());
        let ui = Arc::new(NoOpUiController);
        let registry = Arc::new(ToolRegistry::new(
            Arc::new(tokio::sync::Mutex::new((*config).clone())),
            ui,
        ));
        let validator = Arc::new(SchemaValidator::new());
        let client = Box::new(LocalMcpClient::new(registry, validator));

        let consent_called = Arc::new(AtomicBool::new(false));
        let consent_called_clone = consent_called.clone();

        let consent_callback: ConsentCallback = Arc::new(move |_tool, _call| {
            consent_called_clone.store(true, Ordering::SeqCst);
            false // Deny
        });

        let mut config_mut = (*config).clone();
        config_mut.privacy.require_consent_per_session = true;

        let permission_layer = PermissionLayer::new(client, Arc::new(config_mut))
            .with_consent_callback(consent_callback);

        let call = McpToolCall {
            name: "resources_write".to_string(),
            arguments: serde_json::json!({"path": "test.txt", "content": "hello"}),
        };

        let result = permission_layer.call_tool(call).await;

        assert!(consent_called.load(Ordering::SeqCst));
        assert!(result.is_err());
    }
}
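A wiring sketch (assumptions follow the tests above): a consent callback that denies everything, so resources_write and friends are rejected before reaching the inner client:

fn deny_all_layer(inner: Box<dyn McpClient>, config: Arc<Config>) -> PermissionLayer {
    let consent: ConsentCallback = Arc::new(|tool_name, _call| {
        // A real UI would prompt the user here; this sketch always denies.
        eprintln!("[consent] denied '{tool_name}'");
        false
    });
    PermissionLayer::new(inner, config).with_consent_callback(consent)
}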
@@ -1,446 +0,0 @@
//! Reference MCP connector presets shared across leading client ecosystems.
//!
//! These definitions intentionally avoid vendor-specific naming while capturing
//! the union of commonly shipped servers: local tooling, automation, retrieval,
//! observability, and productivity integrations.

use crate::config::McpServerConfig;
use crate::tools::tool_identifier_violation;
use anyhow::{Result, anyhow};
use std::collections::{HashMap, HashSet};
use std::str::FromStr;

/// High-level preset tiers exposed to CLI/TUI.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PresetTier {
    Standard,
    Extended,
    Full,
}

impl PresetTier {
    pub fn as_str(self) -> &'static str {
        match self {
            PresetTier::Standard => "standard",
            PresetTier::Extended => "extended",
            PresetTier::Full => "full",
        }
    }

    pub fn all() -> &'static [PresetTier] {
        &[PresetTier::Standard, PresetTier::Extended, PresetTier::Full]
    }

    fn description(self) -> &'static str {
        match self {
            PresetTier::Standard => {
                "Core local tooling (filesystem, terminal, git, browser, fetch, python, notebook)."
            }
            PresetTier::Extended => {
                "Standard + retrieval/automation connectors (search, scraping, planning)."
            }
            PresetTier::Full => {
                "Extended + SaaS integrations (observability, productivity, data stores)."
            }
        }
    }
}

impl FromStr for PresetTier {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let normalized = s.trim().to_ascii_lowercase();
        match normalized.as_str() {
            "standard" | "std" => Ok(PresetTier::Standard),
            "extended" | "ext" => Ok(PresetTier::Extended),
            "full" | "all" => Ok(PresetTier::Full),
            other => Err(anyhow!(format!(
                "Unknown preset tier '{other}'. Expected one of: standard, extended, full."
            ))),
        }
    }
}

/// Lightweight description of an MCP connector entry.
#[derive(Debug, Clone, Copy)]
pub struct PresetConnector {
    pub name: &'static str,
    pub command: &'static str,
    pub args: &'static [&'static str],
    pub env: &'static [(&'static str, &'static str)],
    pub description: &'static str,
    pub capabilities: &'static [&'static str],
}

impl PresetConnector {
    pub fn to_config(&self) -> McpServerConfig {
        if let Some(reason) = tool_identifier_violation(self.name) {
            panic!("Invalid preset connector '{}': {reason}", self.name);
        }

        McpServerConfig {
            name: self.name.to_string(),
            command: self.command.to_string(),
            args: self.args.iter().map(|arg| arg.to_string()).collect(),
            transport: "stdio".to_string(),
            env: self
                .env
                .iter()
                .map(|(k, v)| ((*k).to_string(), (*v).to_string()))
                .collect::<HashMap<_, _>>(),
            oauth: None,
            rpc_timeout_secs: None,
        }
    }
}

const STANDARD_CONNECTORS: &[PresetConnector] = &[
    PresetConnector {
        name: "filesystem",
        command: "npx",
        args: &["-y", "@modelcontextprotocol/server-filesystem"],
        env: &[],
        description: "Mount local project directories for read/write operations.",
        capabilities: &["filesystem", "local"],
    },
    PresetConnector {
        name: "terminal",
        command: "npx",
        args: &["-y", "@modelcontextprotocol/server-shell"],
        env: &[],
        description: "Execute shell commands within a sandboxed environment.",
        capabilities: &["shell", "local"],
    },
    PresetConnector {
        name: "git",
        command: "npx",
        args: &["-y", "@modelcontextprotocol/server-git"],
        env: &[],
        description: "Interact with Git repositories for status, diffs, commits.",
        capabilities: &["git", "local"],
    },
    PresetConnector {
        name: "browser",
        command: "npx",
        args: &["-y", "@modelcontextprotocol/server-browser"],
        env: &[],
        description: "Perform scripted browser automation via headless Chromium.",
        capabilities: &["browser", "automation"],
    },
    PresetConnector {
        name: "fetch",
        command: "npx",
        args: &["-y", "@modelcontextprotocol/server-fetch"],
        env: &[],
        description: "Issue structured HTTP requests for REST/JSON APIs.",
        capabilities: &["network"],
    },
    PresetConnector {
        name: "python",
        command: "npx",
        args: &["-y", "@modelcontextprotocol/server-python"],
        env: &[],
        description: "Run Python snippets in an isolated interpreter.",
        capabilities: &["compute", "python"],
    },
    PresetConnector {
        name: "notebook",
        command: "npx",
        args: &["-y", "@modelcontextprotocol/server-notebook"],
        env: &[],
        description: "Evaluate notebook cells and manage Jupyter sessions.",
        capabilities: &["compute", "notebook"],
    },
    PresetConnector {
        name: "sequential_thinking",
        command: "npx",
        args: &["-y", "@modelcontextprotocol/server-sequential-thinking"],
        env: &[],
        description: "Structured reasoning helper with planning support.",
        capabilities: &["planning"],
    },
    PresetConnector {
        name: "puppeteer",
        command: "npx",
        args: &["-y", "@modelcontextprotocol/server-puppeteer"],
        env: &[],
        description: "Full-browser automation via Puppeteer.",
        capabilities: &["browser", "automation"],
    },
];

const EXTENDED_CONNECTORS: &[PresetConnector] = &[
    PresetConnector {
        name: "brave_search",
        command: "npx",
        args: &["-y", "@modelcontextprotocol/server-brave-search"],
        env: &[("BRAVE_API_KEY", "")],
        description: "Search the web using Brave Search APIs.",
        capabilities: &["search", "network"],
    },
    PresetConnector {
        name: "tavily",
        command: "npx",
        args: &["-y", "@tavily/mcp-server"],
        env: &[("TAVILY_API_KEY", "")],
        description: "General-purpose research with Tavily's search/reasoning API.",
        capabilities: &["search", "network"],
    },
    PresetConnector {
        name: "perplexity",
        command: "npx",
        args: &["-y", "@perplexity-ai/mcp-server"],
        env: &[("PPLX_API_KEY", "")],
        description: "Ask questions against Perplexity's API.",
        capabilities: &["qa", "network"],
    },
    PresetConnector {
        name: "firecrawl",
        command: "npx",
        args: &["-y", "@firecrawl/mcp-server"],
        env: &[("FIRECRAWL_TOKEN", "")],
        description: "Crawl and scrape webpages for summarisation.",
        capabilities: &["scrape", "network"],
    },
    PresetConnector {
        name: "memory_bank",
        command: "npx",
        args: &["-y", "@modelcontextprotocol/server-memory"],
        env: &[],
        description: "Persist structured memories for long-lived tasks.",
        capabilities: &["memory"],
    },
];

const FULL_CONNECTORS: &[PresetConnector] = &[
    PresetConnector {
        name: "sentry",
        command: "npx",
        args: &["-y", "@sentry/mcp-server"],
        env: &[("SENTRY_AUTH_TOKEN", "")],
        description: "Query issues and alerts from Sentry.",
        capabilities: &["observability", "network"],
    },
    PresetConnector {
        name: "notion",
        command: "npx",
        args: &["-y", "@notionhq/mcp-server"],
        env: &[("NOTION_API_KEY", "")],
        description: "Access Notion databases and pages.",
        capabilities: &["productivity", "network"],
    },
    PresetConnector {
        name: "slack",
        command: "npx",
        args: &["-y", "@slack/mcp-server"],
        env: &[("SLACK_BOT_TOKEN", "")],
        description: "Send messages and search channels in Slack.",
        capabilities: &["communication", "network"],
    },
    PresetConnector {
        name: "stripe",
        command: "npx",
        args: &["-y", "@stripe/mcp-server"],
        env: &[("STRIPE_API_KEY", "")],
        description: "Inspect customers, invoices, and payment intents.",
        capabilities: &["payments", "network"],
    },
    PresetConnector {
        name: "google_drive",
        command: "npx",
        args: &["-y", "@modelcontextprotocol/server-google-drive"],
        env: &[("GOOGLE_DRIVE_CREDENTIALS", "")],
        description: "Browse and fetch Google Drive documents.",
        capabilities: &["storage", "network"],
    },
    PresetConnector {
        name: "zapier",
        command: "npx",
        args: &["-y", "@zapier/mcp-server"],
        env: &[("ZAPIER_NLA_API_KEY", "")],
        description: "Trigger Zapier actions and workflows.",
        capabilities: &["automation", "network"],
    },
    PresetConnector {
        name: "postgresql",
        command: "npx",
        args: &["-y", "@modelcontextprotocol/server-postgresql"],
        env: &[],
        description: "Run SQL against a PostgreSQL database.",
        capabilities: &["database"],
    },
    PresetConnector {
        name: "sqlite",
        command: "npx",
        args: &["-y", "@modelcontextprotocol/server-sqlite"],
        env: &[],
        description: "Run SQL against local SQLite databases.",
        capabilities: &["database"],
    },
    PresetConnector {
        name: "redis",
        command: "npx",
        args: &["-y", "@modelcontextprotocol/server-redis"],
        env: &[],
        description: "Inspect Redis keys and run commands.",
        capabilities: &["cache", "database"],
    },
    PresetConnector {
        name: "qdrant",
        command: "npx",
        args: &["-y", "@modelcontextprotocol/server-qdrant"],
        env: &[],
        description: "Interact with Qdrant vector collections.",
        capabilities: &["vector", "database"],
    },
];

fn connectors_for_tier_internal(tier: PresetTier) -> Vec<PresetConnector> {
    let mut result = Vec::new();
    result.extend_from_slice(STANDARD_CONNECTORS);
    if matches!(tier, PresetTier::Extended | PresetTier::Full) {
        result.extend_from_slice(EXTENDED_CONNECTORS);
    }
    if matches!(tier, PresetTier::Full) {
        result.extend_from_slice(FULL_CONNECTORS);
    }
    result
}

/// Return connectors for the given tier (including lower tiers).
pub fn connectors_for_tier(tier: PresetTier) -> Vec<PresetConnector> {
    connectors_for_tier_internal(tier)
}

/// Describe the preset tiers for help output.
pub fn tier_descriptions() -> Vec<(PresetTier, &'static str)> {
    PresetTier::all()
        .iter()
        .map(|tier| (*tier, tier.description()))
        .collect()
}

/// Details about changes performed when applying a preset.
#[derive(Debug)]
pub struct PresetApplyReport {
    pub tier: PresetTier,
    pub added: Vec<String>,
    pub updated: Vec<String>,
    pub removed: Vec<String>,
}

impl PresetApplyReport {
    fn new(tier: PresetTier) -> Self {
        Self {
            tier,
            added: Vec::new(),
            updated: Vec::new(),
            removed: Vec::new(),
        }
    }
}

/// Details discovered during audit.
#[derive(Debug)]
pub struct PresetAuditReport {
    pub tier: PresetTier,
    pub missing: Vec<PresetConnector>,
    pub mismatched: Vec<(PresetConnector, McpServerConfig)>,
    pub extra: Vec<McpServerConfig>,
}

impl PresetAuditReport {
    fn new(tier: PresetTier) -> Self {
        Self {
            tier,
            missing: Vec::new(),
            mismatched: Vec::new(),
            extra: Vec::new(),
        }
    }
}

/// Apply the requested preset to the given configuration.
pub fn apply_preset(
    config: &mut crate::config::Config,
    tier: PresetTier,
    prune: bool,
) -> Result<PresetApplyReport> {
    let mut report = PresetApplyReport::new(tier);

    let connectors = connectors_for_tier_internal(tier);
    let expected_names: HashSet<&str> = connectors.iter().map(|c| c.name).collect();

    if prune {
        config.mcp_servers.retain(|existing| {
            if expected_names.contains(existing.name.as_str()) {
                true
            } else {
                report.removed.push(existing.name.clone());
                false
            }
        });
    }

    for connector in connectors {
        match config
            .mcp_servers
            .iter_mut()
            .find(|srv| srv.name == connector.name)
        {
            Some(existing) => {
                let candidate = connector.to_config();
                if existing.command != candidate.command
                    || existing.args != candidate.args
                    || existing.env != candidate.env
                {
                    *existing = candidate;
                    report.updated.push(connector.name.to_string());
                }
            }
            None => {
                config.mcp_servers.push(connector.to_config());
                report.added.push(connector.name.to_string());
            }
        }
    }

    config.refresh_mcp_servers(None)?;
    Ok(report)
}

/// Audit the configuration against a preset without mutating it.
pub fn audit_preset(config: &crate::config::Config, tier: PresetTier) -> PresetAuditReport {
    let mut report = PresetAuditReport::new(tier);

    let connectors = connectors_for_tier_internal(tier);
    let expected: HashMap<&str, &PresetConnector> =
        connectors.iter().map(|c| (c.name, c)).collect();
    let mut seen = HashSet::new();

    for server in &config.mcp_servers {
        if let Some(expected_connector) = expected.get(server.name.as_str()) {
            seen.insert(server.name.as_str());
            let expected_config = expected_connector.to_config();
            if expected_config.command != server.command
                || expected_config.args != server.args
                || expected_config.env != server.env
            {
                report
                    .mismatched
                    .push((**expected_connector, server.clone()));
            }
        } else {
            report.extra.push(server.clone());
        }
    }

    for connector in connectors {
        if !seen.contains(connector.name) {
            report.missing.push(connector);
        }
    }

    report
}
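A usage sketch for the two entry points above: apply the extended tier without pruning, then audit to confirm nothing is missing (hypothetical helper, using the anyhow::Result imported here):

fn demo_presets(config: &mut crate::config::Config) -> Result<()> {
    let report = apply_preset(config, PresetTier::Extended, false)?;
    println!(
        "tier '{}': {} added, {} updated, {} removed",
        report.tier.as_str(),
        report.added.len(),
        report.updated.len(),
        report.removed.len()
    );

    let audit = audit_preset(config, PresetTier::Extended);
    assert!(audit.missing.is_empty());
    Ok(())
}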
@@ -1,389 +0,0 @@
/// MCP Protocol Definitions
///
/// This module defines the JSON-RPC protocol contracts for the Model Context Protocol (MCP).
/// It includes request/response schemas, error codes, and versioning semantics.
use serde::{Deserialize, Serialize};
use serde_json::Value;

/// MCP Protocol version - uses semantic versioning
pub const PROTOCOL_VERSION: &str = "1.0.0";

/// JSON-RPC version constant
pub const JSONRPC_VERSION: &str = "2.0";

// ============================================================================
// Error Codes and Handling
// ============================================================================

/// Standard JSON-RPC error codes following the spec
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ErrorCode(pub i64);

impl ErrorCode {
    // Standard JSON-RPC 2.0 errors
    pub const PARSE_ERROR: Self = Self(-32700);
    pub const INVALID_REQUEST: Self = Self(-32600);
    pub const METHOD_NOT_FOUND: Self = Self(-32601);
    pub const INVALID_PARAMS: Self = Self(-32602);
    pub const INTERNAL_ERROR: Self = Self(-32603);

    // MCP-specific errors (range -32000 to -32099)
    pub const TOOL_NOT_FOUND: Self = Self(-32000);
    pub const TOOL_EXECUTION_FAILED: Self = Self(-32001);
    pub const PERMISSION_DENIED: Self = Self(-32002);
    pub const RESOURCE_NOT_FOUND: Self = Self(-32003);
    pub const TIMEOUT: Self = Self(-32004);
    pub const VALIDATION_ERROR: Self = Self(-32005);
    pub const PATH_TRAVERSAL: Self = Self(-32006);
    pub const RATE_LIMIT_EXCEEDED: Self = Self(-32007);
}

/// Structured error response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RpcError {
    pub code: i64,
    pub message: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<Value>,
}

impl RpcError {
    pub fn new(code: ErrorCode, message: impl Into<String>) -> Self {
        Self {
            code: code.0,
            message: message.into(),
            data: None,
        }
    }

    pub fn with_data(mut self, data: Value) -> Self {
        self.data = Some(data);
        self
    }

    pub fn parse_error(message: impl Into<String>) -> Self {
        Self::new(ErrorCode::PARSE_ERROR, message)
    }

    pub fn invalid_request(message: impl Into<String>) -> Self {
        Self::new(ErrorCode::INVALID_REQUEST, message)
    }

    pub fn method_not_found(method: &str) -> Self {
        Self::new(
            ErrorCode::METHOD_NOT_FOUND,
            format!("Method not found: {}", method),
        )
    }

    pub fn invalid_params(message: impl Into<String>) -> Self {
        Self::new(ErrorCode::INVALID_PARAMS, message)
    }

    pub fn internal_error(message: impl Into<String>) -> Self {
        Self::new(ErrorCode::INTERNAL_ERROR, message)
    }

    pub fn tool_not_found(tool_name: &str) -> Self {
        Self::new(
            ErrorCode::TOOL_NOT_FOUND,
            format!("Tool not found: {}", tool_name),
        )
    }

    pub fn permission_denied(message: impl Into<String>) -> Self {
        Self::new(ErrorCode::PERMISSION_DENIED, message)
    }

    pub fn path_traversal() -> Self {
        Self::new(ErrorCode::PATH_TRAVERSAL, "Path traversal attempt detected")
    }
}

// ============================================================================
// Request/Response Structures
// ============================================================================

/// JSON-RPC request structure
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RpcRequest {
    pub jsonrpc: String,
    pub id: RequestId,
    pub method: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Value>,
}

impl RpcRequest {
    pub fn new(id: RequestId, method: impl Into<String>, params: Option<Value>) -> Self {
        Self {
            jsonrpc: JSONRPC_VERSION.to_string(),
            id,
            method: method.into(),
            params,
        }
    }
}

/// JSON-RPC response structure (success)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RpcResponse {
    pub jsonrpc: String,
    pub id: RequestId,
    pub result: Value,
}

impl RpcResponse {
    pub fn new(id: RequestId, result: Value) -> Self {
        Self {
            jsonrpc: JSONRPC_VERSION.to_string(),
            id,
            result,
        }
    }
}

/// JSON-RPC error response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RpcErrorResponse {
    pub jsonrpc: String,
    pub id: RequestId,
    pub error: RpcError,
}

impl RpcErrorResponse {
    pub fn new(id: RequestId, error: RpcError) -> Self {
        Self {
            jsonrpc: JSONRPC_VERSION.to_string(),
            id,
            error,
        }
    }
}

/// JSON-RPC notification (no id). Used for streaming partial results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RpcNotification {
    pub jsonrpc: String,
    pub method: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub params: Option<Value>,
}

impl RpcNotification {
    pub fn new(method: impl Into<String>, params: Option<Value>) -> Self {
        Self {
            jsonrpc: JSONRPC_VERSION.to_string(),
            method: method.into(),
            params,
        }
    }
}

/// Request ID can be a string or a number (null IDs are not modeled)
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(untagged)]
pub enum RequestId {
    Number(u64),
    String(String),
}

impl From<u64> for RequestId {
    fn from(n: u64) -> Self {
        Self::Number(n)
    }
}

impl From<String> for RequestId {
    fn from(s: String) -> Self {
        Self::String(s)
    }
}

// ============================================================================
// MCP Method Names
// ============================================================================

/// Standard MCP methods
pub mod methods {
    pub const INITIALIZE: &str = "initialize";
    pub const TOOLS_LIST: &str = "tools/list";
    pub const TOOLS_CALL: &str = "tools/call";
    pub const RESOURCES_LIST: &str = "resources_list";
    pub const RESOURCES_GET: &str = "resources_get";
    pub const RESOURCES_WRITE: &str = "resources_write";
    pub const RESOURCES_DELETE: &str = "resources_delete";
    pub const MODELS_LIST: &str = "models/list";
}

// ============================================================================
// Initialization Protocol
// ============================================================================

/// Initialize request parameters
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InitializeParams {
    pub protocol_version: String,
    pub client_info: ClientInfo,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub capabilities: Option<ClientCapabilities>,
}

impl Default for InitializeParams {
    fn default() -> Self {
        Self {
            protocol_version: PROTOCOL_VERSION.to_string(),
            client_info: ClientInfo {
                name: "owlen".to_string(),
                version: env!("CARGO_PKG_VERSION").to_string(),
            },
            capabilities: None,
        }
    }
}

/// Client information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientInfo {
    pub name: String,
    pub version: String,
}

/// Client capabilities
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ClientCapabilities {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub supports_streaming: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub supports_cancellation: Option<bool>,
}

/// Initialize response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InitializeResult {
    pub protocol_version: String,
    pub server_info: ServerInfo,
    pub capabilities: ServerCapabilities,
}

/// Server information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServerInfo {
    pub name: String,
    pub version: String,
}

/// Server capabilities
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ServerCapabilities {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub supports_tools: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub supports_resources: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub supports_streaming: Option<bool>,
}

// ============================================================================
// Tool Call Protocol
// ============================================================================

/// Parameters for tools/list
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ToolsListParams {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub filter: Option<String>,
}

/// Parameters for tools/call
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolsCallParams {
    pub name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<Value>,
}

/// Result of tools/call
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolsCallResult {
    pub success: bool,
    pub output: Value,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<Value>,
}

// ============================================================================
// Resource Protocol
// ============================================================================

/// Parameters for resources/list
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourcesListParams {
    pub path: String,
}

/// Parameters for resources/get
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourcesGetParams {
    pub path: String,
}

/// Parameters for resources/write
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourcesWriteParams {
    pub path: String,
    pub content: String,
}

/// Parameters for resources/delete
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResourcesDeleteParams {
    pub path: String,
}

// ============================================================================
// Versioning and Compatibility
// ============================================================================

/// Check if a protocol version is compatible
pub fn is_compatible(client_version: &str, server_version: &str) -> bool {
    // For now, simple exact match on major version
    let client_major = client_version.split('.').next().unwrap_or("0");
    let server_major = server_version.split('.').next().unwrap_or("0");
    client_major == server_major
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_codes() {
        let err = RpcError::tool_not_found("test_tool");
        assert_eq!(err.code, ErrorCode::TOOL_NOT_FOUND.0);
        assert!(err.message.contains("test_tool"));
    }

    #[test]
    fn test_version_compatibility() {
        assert!(is_compatible("1.0.0", "1.0.0"));
        assert!(is_compatible("1.0.0", "1.1.0"));
        assert!(is_compatible("1.2.5", "1.0.0"));
        assert!(!is_compatible("1.0.0", "2.0.0"));
        assert!(!is_compatible("2.0.0", "1.0.0"));
    }

    #[test]
    fn test_request_serialization() {
        let req = RpcRequest::new(
            RequestId::Number(1),
            "tools/call",
            Some(serde_json::json!({"name": "test"})),
        );
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"jsonrpc\":\"2.0\""));
        assert!(json.contains("\"method\":\"tools/call\""));
    }
}
File diff suppressed because it is too large
@@ -1,189 +0,0 @@
//! Operating modes for Owlen
//!
//! Defines the different modes in which Owlen can operate and their associated
//! tool availability policies.

use serde::{Deserialize, Serialize};
use std::str::FromStr;

use crate::tools::{WEB_SEARCH_TOOL_NAME, canonical_tool_name};

/// Operating mode for Owlen
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum Mode {
    /// Chat mode - limited tool access, safe for general conversation
    #[default]
    Chat,
    /// Code mode - full tool access for development tasks
    Code,
}

impl Mode {
    /// Get the display name for this mode
    pub fn display_name(&self) -> &'static str {
        match self {
            Mode::Chat => "chat",
            Mode::Code => "code",
        }
    }
}

impl std::fmt::Display for Mode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.display_name())
    }
}

impl FromStr for Mode {
    type Err = String;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_lowercase().as_str() {
            "chat" => Ok(Mode::Chat),
            "code" => Ok(Mode::Code),
            _ => Err(format!(
                "Invalid mode: '{}'. Valid modes are 'chat' or 'code'",
                s
            )),
        }
    }
}

/// Configuration for tool availability in different modes
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModeConfig {
    /// Tools allowed in chat mode
    #[serde(default = "ModeConfig::default_chat_tools")]
    pub chat: ModeToolConfig,
    /// Tools allowed in code mode
    #[serde(default = "ModeConfig::default_code_tools")]
    pub code: ModeToolConfig,
}

impl Default for ModeConfig {
    fn default() -> Self {
        Self {
            chat: Self::default_chat_tools(),
            code: Self::default_code_tools(),
        }
    }
}

impl ModeConfig {
    fn default_chat_tools() -> ModeToolConfig {
        ModeToolConfig {
            allowed_tools: vec![WEB_SEARCH_TOOL_NAME.to_string()],
        }
    }

    fn default_code_tools() -> ModeToolConfig {
        ModeToolConfig {
            allowed_tools: vec!["*".to_string()], // All tools allowed
        }
    }

    /// Check if a tool is allowed in the given mode
    pub fn is_tool_allowed(&self, mode: Mode, tool_name: &str) -> bool {
        let config = match mode {
            Mode::Chat => &self.chat,
            Mode::Code => &self.code,
        };

        config.is_tool_allowed(tool_name)
    }
}

/// Tool configuration for a specific mode
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModeToolConfig {
    /// List of allowed tools. Use "*" to allow all tools.
    pub allowed_tools: Vec<String>,
}

impl ModeToolConfig {
    /// Check if a tool is allowed in this mode
    pub fn is_tool_allowed(&self, tool_name: &str) -> bool {
        // Check for wildcard
        if self.allowed_tools.iter().any(|t| t == "*") {
            return true;
        }

        // Check if tool is explicitly listed
        let target = canonical_tool_name(tool_name);
        self.allowed_tools
            .iter()
            .any(|t| canonical_tool_name(t) == target)
    }
}
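
// Hedged sketch, not part of the original file: wiring the mode policy into a
// dispatcher. `dispatch_tool` is a hypothetical caller; only `ModeConfig`,
// `Mode`, and `is_tool_allowed` come from this module.
fn dispatch_tool(config: &ModeConfig, mode: Mode, tool_name: &str) -> Result<(), String> {
    if !config.is_tool_allowed(mode, tool_name) {
        return Err(format!("tool '{tool_name}' is not allowed in {mode} mode"));
    }
    // A real dispatcher would execute the tool here.
    Ok(())
}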

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mode_display() {
        assert_eq!(Mode::Chat.to_string(), "chat");
        assert_eq!(Mode::Code.to_string(), "code");
    }

    #[test]
    fn test_mode_from_str() {
        assert_eq!("chat".parse::<Mode>(), Ok(Mode::Chat));
        assert_eq!("code".parse::<Mode>(), Ok(Mode::Code));
        assert_eq!("CHAT".parse::<Mode>(), Ok(Mode::Chat));
        assert_eq!("CODE".parse::<Mode>(), Ok(Mode::Code));
        assert!("invalid".parse::<Mode>().is_err());
    }

    #[test]
    fn test_default_mode() {
        assert_eq!(Mode::default(), Mode::Chat);
    }

    #[test]
    fn test_chat_mode_restrictions() {
        let config = ModeConfig::default();

        // Web search should be allowed in chat mode
        assert!(config.is_tool_allowed(Mode::Chat, WEB_SEARCH_TOOL_NAME));
        assert!(config.is_tool_allowed(Mode::Chat, "web_search"));

        // Code exec should not be allowed in chat mode
        assert!(!config.is_tool_allowed(Mode::Chat, "code_exec"));
        assert!(!config.is_tool_allowed(Mode::Chat, "file_write"));
    }

    #[test]
    fn test_code_mode_allows_all() {
        let config = ModeConfig::default();

        // All tools should be allowed in code mode
        assert!(config.is_tool_allowed(Mode::Code, WEB_SEARCH_TOOL_NAME));
        assert!(config.is_tool_allowed(Mode::Code, "web_search"));
        assert!(config.is_tool_allowed(Mode::Code, "code_exec"));
        assert!(config.is_tool_allowed(Mode::Code, "file_write"));
        assert!(config.is_tool_allowed(Mode::Code, "anything"));
    }

    #[test]
    fn test_wildcard_tool_config() {
        let config = ModeToolConfig {
            allowed_tools: vec!["*".to_string()],
        };

        assert!(config.is_tool_allowed("any_tool"));
        assert!(config.is_tool_allowed("another_tool"));
    }

    #[test]
    fn test_explicit_tool_list() {
        let config = ModeToolConfig {
            allowed_tools: vec!["tool1".to_string(), "tool2".to_string()],
        };

        assert!(config.is_tool_allowed("tool1"));
        assert!(config.is_tool_allowed("tool2"));
        assert!(!config.is_tool_allowed("tool3"));
    }
}
@@ -1,209 +0,0 @@
pub mod details;

pub use details::{DetailedModelInfo, ModelInfoRetrievalError};

use crate::Result;
use crate::types::ModelInfo;
use std::collections::HashMap;
use std::future::Future;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;

#[derive(Default, Debug)]
struct ModelCache {
    models: Vec<ModelInfo>,
    last_refresh: Option<Instant>,
}

/// Caches model listings for improved selection performance
#[derive(Clone, Debug)]
pub struct ModelManager {
    cache: Arc<RwLock<ModelCache>>,
    ttl: Duration,
}

impl ModelManager {
    /// Create a new manager with the desired cache TTL
    pub fn new(ttl: Duration) -> Self {
        Self {
            cache: Arc::new(RwLock::new(ModelCache::default())),
            ttl,
        }
    }

    /// Get cached models, refreshing via the provided fetcher when stale. Returns the up-to-date model list.
    pub async fn get_or_refresh<F, Fut>(
        &self,
        force_refresh: bool,
        fetcher: F,
    ) -> Result<Vec<ModelInfo>>
    where
        F: FnOnce() -> Fut,
        Fut: Future<Output = Result<Vec<ModelInfo>>>,
    {
        if let (false, Some(models)) = (force_refresh, self.cached_if_fresh().await) {
            return Ok(models);
        }

        let models = fetcher().await?;
        let mut cache = self.cache.write().await;
        cache.models = models.clone();
        cache.last_refresh = Some(Instant::now());
        Ok(models)
    }

    /// Return cached models without refreshing
    pub async fn cached(&self) -> Vec<ModelInfo> {
        self.cache.read().await.models.clone()
    }

    /// Drop cached models, forcing next call to refresh
    pub async fn invalidate(&self) {
        let mut cache = self.cache.write().await;
        cache.models.clear();
        cache.last_refresh = None;
    }

    /// Select a model by id or name from the cache
    pub async fn select(&self, identifier: &str) -> Option<ModelInfo> {
        let cache = self.cache.read().await;
        cache
            .models
            .iter()
            .find(|m| m.id == identifier || m.name == identifier)
            .cloned()
    }

    async fn cached_if_fresh(&self) -> Option<Vec<ModelInfo>> {
        let cache = self.cache.read().await;
        let fresh = matches!(cache.last_refresh, Some(ts) if ts.elapsed() < self.ttl);
        if fresh && !cache.models.is_empty() {
            Some(cache.models.clone())
        } else {
            None
        }
    }
}
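
// Hedged usage sketch, not part of the original file: the fetcher closure is
// only awaited on a cache miss or when `force_refresh` is true.
// `refresh_demo` is hypothetical; a real caller would perform a network fetch.
async fn refresh_demo(manager: &ModelManager) -> Result<Vec<ModelInfo>> {
    manager
        .get_or_refresh(false, || async {
            // A provider fetch would go here; an empty list stands in.
            Ok(Vec::new())
        })
        .await
}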

#[derive(Default, Debug)]
struct ModelDetailsCacheInner {
    by_key: HashMap<String, DetailedModelInfo>,
    name_to_key: HashMap<String, String>,
    fetched_at: HashMap<String, Instant>,
}

/// Cache for rich model details, indexed by digest when available.
#[derive(Clone, Debug)]
pub struct ModelDetailsCache {
    inner: Arc<RwLock<ModelDetailsCacheInner>>,
    ttl: Duration,
}

impl ModelDetailsCache {
    /// Create a new details cache with the provided TTL.
    pub fn new(ttl: Duration) -> Self {
        Self {
            inner: Arc::new(RwLock::new(ModelDetailsCacheInner::default())),
            ttl,
        }
    }

    /// Try to read cached details for the provided model name.
    pub async fn get(&self, name: &str) -> Option<DetailedModelInfo> {
        let mut inner = self.inner.write().await;
        let key = inner.name_to_key.get(name).cloned()?;
        let stale = inner
            .fetched_at
            .get(&key)
            .is_some_and(|ts| ts.elapsed() >= self.ttl);
        if stale {
            inner.by_key.remove(&key);
            inner.name_to_key.remove(name);
            inner.fetched_at.remove(&key);
            return None;
        }
        inner.by_key.get(&key).cloned()
    }

    /// Cache the provided details, overwriting existing entries.
    pub async fn insert(&self, info: DetailedModelInfo) {
        let key = info.digest.clone().unwrap_or_else(|| info.name.clone());
        let mut inner = self.inner.write().await;

        // Remove prior mappings for this model name (possibly different digest).
        if let Some(previous_key) = inner.name_to_key.get(&info.name).cloned()
            && previous_key != key
        {
            inner.by_key.remove(&previous_key);
            inner.fetched_at.remove(&previous_key);
        }

        inner.fetched_at.insert(key.clone(), Instant::now());
        inner.name_to_key.insert(info.name.clone(), key.clone());
        inner.by_key.insert(key, info);
    }

    /// Remove a specific model from the cache.
    pub async fn invalidate(&self, name: &str) {
        let mut inner = self.inner.write().await;
        if let Some(key) = inner.name_to_key.remove(name) {
            inner.by_key.remove(&key);
            inner.fetched_at.remove(&key);
        }
    }

    /// Clear the entire cache.
    pub async fn invalidate_all(&self) {
        let mut inner = self.inner.write().await;
        inner.by_key.clear();
        inner.name_to_key.clear();
        inner.fetched_at.clear();
    }

    /// Return all cached values regardless of freshness.
    pub async fn cached(&self) -> Vec<DetailedModelInfo> {
        let inner = self.inner.read().await;
        inner.by_key.values().cloned().collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;
    use tokio::time::sleep;

    fn sample_details(name: &str) -> DetailedModelInfo {
        DetailedModelInfo {
            name: name.to_string(),
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn model_details_cache_returns_cached_entry() {
        let cache = ModelDetailsCache::new(Duration::from_millis(50));
        let info = sample_details("llama");
        cache.insert(info.clone()).await;
        let cached = cache.get("llama").await;
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().name, "llama");
    }

    #[tokio::test]
    async fn model_details_cache_expires_based_on_ttl() {
        let cache = ModelDetailsCache::new(Duration::from_millis(10));
        cache.insert(sample_details("phi")).await;
        sleep(Duration::from_millis(30)).await;
        assert!(cache.get("phi").await.is_none());
    }

    #[tokio::test]
    async fn model_details_cache_invalidate_removes_entry() {
        let cache = ModelDetailsCache::new(Duration::from_secs(1));
        cache.insert(sample_details("mistral")).await;
        cache.invalidate("mistral").await;
        assert!(cache.get("mistral").await.is_none());
    }
}
@@ -1,105 +0,0 @@
//! Detailed model metadata for provider inspection features.
//!
//! These types capture richer information about locally available models
//! than the lightweight [`crate::types::ModelInfo`] listing and back the
//! higher-level inspection UI exposed in the Owlen TUI.

use serde::{Deserialize, Serialize};

/// Rich metadata about an Ollama model.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct DetailedModelInfo {
    /// Canonical model name (including tag).
    pub name: String,
    /// Reported architecture or model format.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub architecture: Option<String>,
    /// Human-readable parameter / quantisation summary.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameters: Option<String>,
    /// Context window length, if provided.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context_length: Option<u64>,
    /// Embedding vector length for embedding-capable models.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub embedding_length: Option<u64>,
    /// Quantisation level (e.g., Q4_0, Q5_K_M).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub quantization: Option<String>,
    /// Primary family identifier (e.g., llama3).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub family: Option<String>,
    /// Additional family tags reported by Ollama.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub families: Vec<String>,
    /// Verbose parameter size description (e.g., 70B parameters).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parameter_size: Option<String>,
    /// Default prompt template packaged with the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub template: Option<String>,
    /// Default system prompt packaged with the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub system: Option<String>,
    /// License string provided by the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub license: Option<String>,
    /// Raw modelfile contents (if available).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modelfile: Option<String>,
    /// Modification timestamp (ISO-8601) if reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modified_at: Option<String>,
    /// Approximate model size in bytes.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub size: Option<u64>,
    /// Digest / checksum used by Ollama (sha256).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub digest: Option<String>,
}

impl DetailedModelInfo {
    /// Convenience helper that normalises empty strings to `None`.
    pub fn with_normalised_strings(mut self) -> Self {
        if self.architecture.as_ref().is_some_and(String::is_empty) {
            self.architecture = None;
        }
        if self.parameters.as_ref().is_some_and(String::is_empty) {
            self.parameters = None;
        }
        if self.quantization.as_ref().is_some_and(String::is_empty) {
            self.quantization = None;
        }
        if self.family.as_ref().is_some_and(String::is_empty) {
            self.family = None;
        }
        if self.parameter_size.as_ref().is_some_and(String::is_empty) {
            self.parameter_size = None;
        }
        if self.template.as_ref().is_some_and(String::is_empty) {
            self.template = None;
        }
        if self.system.as_ref().is_some_and(String::is_empty) {
            self.system = None;
        }
        if self.license.as_ref().is_some_and(String::is_empty) {
            self.license = None;
        }
        if self.modelfile.as_ref().is_some_and(String::is_empty) {
            self.modelfile = None;
        }
        if self.digest.as_ref().is_some_and(String::is_empty) {
            self.digest = None;
        }
        self
    }
}
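
// Hedged sketch, not part of the original file: empty strings coming back from
// a raw API payload collapse to `None`, so the UI can treat "missing" uniformly.
fn normalisation_demo() {
    let info = DetailedModelInfo {
        name: "llama3:8b".to_string(),
        family: Some(String::new()), // empty field as a raw payload might report it
        ..Default::default()
    }
    .with_normalised_strings();
    assert!(info.family.is_none());
}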

/// Error payload returned when model inspection fails for a specific model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfoRetrievalError {
    /// Model that failed to resolve.
    pub model_name: String,
    /// Human-readable description of the failure.
    pub error_message: String,
}
@@ -1,507 +0,0 @@
use std::time::Duration as StdDuration;

use chrono::{DateTime, Duration, Utc};
use reqwest::Client;
use serde::{Deserialize, Serialize};

use crate::{Error, Result, config::McpOAuthConfig};

/// Persisted OAuth token set for MCP servers and providers.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub struct OAuthToken {
    /// Bearer access token returned by the authorization server.
    pub access_token: String,
    /// Optional refresh token if the provider issues one.
    #[serde(default)]
    pub refresh_token: Option<String>,
    /// Absolute UTC expiration timestamp for the access token.
    #[serde(default)]
    pub expires_at: Option<DateTime<Utc>>,
    /// Optional space-delimited scope string supplied by the provider.
    #[serde(default)]
    pub scope: Option<String>,
    /// Token type reported by the provider (typically `Bearer`).
    #[serde(default)]
    pub token_type: Option<String>,
}

impl OAuthToken {
    /// Returns `true` if the access token has expired at the provided instant.
    pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
        matches!(self.expires_at, Some(expiry) if now >= expiry)
    }

    /// Returns `true` if the token will expire within the supplied duration window.
    pub fn will_expire_within(&self, window: Duration, now: DateTime<Utc>) -> bool {
        matches!(self.expires_at, Some(expiry) if expiry - now <= window)
    }
}
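
// Hedged sketch, not part of the original file: refresh ahead of expiry so a
// request never carries a stale bearer token. `ensure_fresh` and the
// five-minute window are assumptions for illustration.
async fn ensure_fresh(client: &OAuthClient, token: OAuthToken) -> Result<OAuthToken> {
    let refresh_window = Duration::minutes(5);
    match (&token.refresh_token, token.will_expire_within(refresh_window, Utc::now())) {
        (Some(refresh), true) => client.refresh_token(refresh).await,
        _ => Ok(token),
    }
}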

/// Active device-authorization session details returned by the authorization server.
#[derive(Debug, Clone)]
pub struct DeviceAuthorization {
    pub device_code: String,
    pub user_code: String,
    pub verification_uri: String,
    pub verification_uri_complete: Option<String>,
    pub expires_at: DateTime<Utc>,
    pub interval: StdDuration,
    pub message: Option<String>,
}

impl DeviceAuthorization {
    pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
        now >= self.expires_at
    }
}

/// Result of polling the token endpoint during a device-authorization flow.
#[derive(Debug, Clone)]
pub enum DevicePollState {
    Pending { retry_in: StdDuration },
    Complete(OAuthToken),
}

pub struct OAuthClient {
    http: Client,
    config: McpOAuthConfig,
}

impl OAuthClient {
    pub fn new(config: McpOAuthConfig) -> Result<Self> {
        let http = Client::builder()
            .user_agent("OwlenOAuth/1.0")
            .build()
            .map_err(|err| Error::Network(format!("Failed to construct HTTP client: {err}")))?;
        Ok(Self { http, config })
    }

    fn scope_value(&self) -> Option<String> {
        if self.config.scopes.is_empty() {
            None
        } else {
            Some(self.config.scopes.join(" "))
        }
    }

    fn token_request_base(&self) -> Vec<(String, String)> {
        let mut params = vec![("client_id".to_string(), self.config.client_id.clone())];
        if let Some(secret) = &self.config.client_secret {
            params.push(("client_secret".to_string(), secret.clone()));
        }
        params
    }

    pub async fn start_device_authorization(&self) -> Result<DeviceAuthorization> {
        let device_url = self
            .config
            .device_authorization_url
            .as_ref()
            .ok_or_else(|| {
                Error::Config("Device authorization endpoint is not configured.".to_string())
            })?;

        let mut params = self.token_request_base();
        if let Some(scope) = self.scope_value() {
            params.push(("scope".to_string(), scope));
        }

        let response = self
            .http
            .post(device_url)
            .form(&params)
            .send()
            .await
            .map_err(|err| map_http_error("start device authorization", err))?;

        let status = response.status();
        let payload = response
            .json::<DeviceAuthorizationResponse>()
            .await
            .map_err(|err| {
                Error::Auth(format!(
                    "Failed to parse device authorization response (status {status}): {err}"
                ))
            })?;

        let expires_at =
            Utc::now() + Duration::seconds(payload.expires_in.min(i64::MAX as u64) as i64);
        let interval = StdDuration::from_secs(payload.interval.unwrap_or(5).max(1));

        Ok(DeviceAuthorization {
            device_code: payload.device_code,
            user_code: payload.user_code,
            verification_uri: payload.verification_uri,
            verification_uri_complete: payload.verification_uri_complete,
            expires_at,
            interval,
            message: payload.message,
        })
    }

    pub async fn poll_device_token(&self, auth: &DeviceAuthorization) -> Result<DevicePollState> {
        let mut params = self.token_request_base();
        params.push(("grant_type".to_string(), DEVICE_CODE_GRANT.to_string()));
        params.push(("device_code".to_string(), auth.device_code.clone()));
        if let Some(scope) = self.scope_value() {
            params.push(("scope".to_string(), scope));
        }

        let response = self
            .http
            .post(&self.config.token_url)
            .form(&params)
            .send()
            .await
            .map_err(|err| map_http_error("poll device token", err))?;

        let status = response.status();
        let text = response
            .text()
            .await
            .map_err(|err| map_http_error("read token response", err))?;

        if status.is_success() {
            let payload: TokenResponse = serde_json::from_str(&text).map_err(|err| {
                Error::Auth(format!(
                    "Failed to parse OAuth token response: {err}; body: {text}"
                ))
            })?;
            return Ok(DevicePollState::Complete(oauth_token_from_response(
                payload,
            )));
        }

        let error = serde_json::from_str::<OAuthErrorResponse>(&text).unwrap_or_else(|_| {
            OAuthErrorResponse {
                error: "unknown_error".to_string(),
                error_description: Some(text.clone()),
            }
        });

        match error.error.as_str() {
            "authorization_pending" => Ok(DevicePollState::Pending {
                retry_in: auth.interval,
            }),
            "slow_down" => Ok(DevicePollState::Pending {
                retry_in: auth.interval.saturating_add(StdDuration::from_secs(5)),
            }),
            "access_denied" => {
                Err(Error::Auth(error.error_description.unwrap_or_else(|| {
                    "User declined authorization".to_string()
                })))
            }
            "expired_token" | "expired_device_code" => {
                Err(Error::Auth(error.error_description.unwrap_or_else(|| {
                    "Device authorization expired".to_string()
                })))
            }
            other => Err(Error::Auth(
                error
                    .error_description
                    .unwrap_or_else(|| format!("OAuth error: {other}")),
            )),
        }
    }

    pub async fn refresh_token(&self, refresh_token: &str) -> Result<OAuthToken> {
        let mut params = self.token_request_base();
        params.push(("grant_type".to_string(), "refresh_token".to_string()));
        params.push(("refresh_token".to_string(), refresh_token.to_string()));
        if let Some(scope) = self.scope_value() {
            params.push(("scope".to_string(), scope));
        }

        let response = self
            .http
            .post(&self.config.token_url)
            .form(&params)
            .send()
            .await
            .map_err(|err| map_http_error("refresh OAuth token", err))?;

        let status = response.status();
        let text = response
            .text()
            .await
            .map_err(|err| map_http_error("read refresh response", err))?;

        if status.is_success() {
            let payload: TokenResponse = serde_json::from_str(&text).map_err(|err| {
                Error::Auth(format!(
                    "Failed to parse OAuth refresh response: {err}; body: {text}"
                ))
            })?;
            Ok(oauth_token_from_response(payload))
        } else {
            let error = serde_json::from_str::<OAuthErrorResponse>(&text).unwrap_or_else(|_| {
                OAuthErrorResponse {
                    error: "unknown_error".to_string(),
                    error_description: Some(text.clone()),
                }
            });
            Err(Error::Auth(error.error_description.unwrap_or_else(|| {
                format!("OAuth token refresh failed: {}", error.error)
            })))
        }
    }
}
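
// Hedged sketch, not part of the original file: driving the device flow to
// completion. `complete_device_flow` is hypothetical; `tokio::time::sleep` is
// assumed as the timer since the crate already depends on tokio.
async fn complete_device_flow(client: &OAuthClient) -> Result<OAuthToken> {
    let auth = client.start_device_authorization().await?;
    println!("Visit {} and enter code {}", auth.verification_uri, auth.user_code);
    loop {
        if auth.is_expired(Utc::now()) {
            return Err(Error::Auth("device authorization expired".to_string()));
        }
        match client.poll_device_token(&auth).await? {
            DevicePollState::Complete(token) => return Ok(token),
            DevicePollState::Pending { retry_in } => tokio::time::sleep(retry_in).await,
        }
    }
}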
const DEVICE_CODE_GRANT: &str = "urn:ietf:params:oauth:grant-type:device_code";

#[derive(Debug, Deserialize)]
struct DeviceAuthorizationResponse {
    device_code: String,
    user_code: String,
    verification_uri: String,
    #[serde(default)]
    verification_uri_complete: Option<String>,
    expires_in: u64,
    #[serde(default)]
    interval: Option<u64>,
    #[serde(default)]
    message: Option<String>,
}

#[derive(Debug, Deserialize)]
struct TokenResponse {
    access_token: String,
    #[serde(default)]
    refresh_token: Option<String>,
    #[serde(default)]
    expires_in: Option<u64>,
    #[serde(default)]
    scope: Option<String>,
    #[serde(default)]
    token_type: Option<String>,
}

#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
    error: String,
    #[serde(default)]
    error_description: Option<String>,
}

fn oauth_token_from_response(payload: TokenResponse) -> OAuthToken {
    let expires_at = payload
        .expires_in
        .map(|seconds| seconds.min(i64::MAX as u64) as i64)
        .map(|seconds| Utc::now() + Duration::seconds(seconds));

    OAuthToken {
        access_token: payload.access_token,
        refresh_token: payload.refresh_token,
        expires_at,
        scope: payload.scope,
        token_type: payload.token_type,
    }
}

fn map_http_error(action: &str, err: reqwest::Error) -> Error {
    if err.is_timeout() {
        Error::Timeout(format!("OAuth {action} request timed out: {err}"))
    } else if err.is_connect() {
        Error::Network(format!("OAuth {action} connection error: {err}"))
    } else {
        Error::Network(format!("OAuth {action} request failed: {err}"))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use httpmock::prelude::*;
    use serde_json::json;

    fn config_for(server: &MockServer) -> McpOAuthConfig {
        McpOAuthConfig {
            client_id: "test-client".to_string(),
            client_secret: None,
            authorize_url: server.url("/authorize"),
            token_url: server.url("/token"),
            device_authorization_url: Some(server.url("/device")),
            redirect_url: None,
            scopes: vec!["repo".to_string(), "user".to_string()],
            token_env: None,
            header: None,
            header_prefix: None,
        }
    }

    fn sample_device_authorization() -> DeviceAuthorization {
        DeviceAuthorization {
            device_code: "device-123".to_string(),
            user_code: "ABCD-EFGH".to_string(),
            verification_uri: "https://example.test/activate".to_string(),
            verification_uri_complete: Some(
                "https://example.test/activate?user_code=ABCD-EFGH".to_string(),
            ),
            expires_at: Utc::now() + Duration::minutes(10),
            interval: StdDuration::from_secs(5),
            message: Some("Open the verification URL and enter the code.".to_string()),
        }
    }

    #[tokio::test]
    async fn start_device_authorization_returns_payload() {
        let server = MockServer::start_async().await;
        let device_mock = server
            .mock_async(|when, then| {
                when.method(POST).path("/device");
                then.status(200)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "device_code": "device-123",
                        "user_code": "ABCD-EFGH",
                        "verification_uri": "https://example.test/activate",
                        "verification_uri_complete": "https://example.test/activate?user_code=ABCD-EFGH",
                        "expires_in": 600,
                        "interval": 7,
                        "message": "Open the verification URL and enter the code."
                    }));
            })
            .await;

        let client = OAuthClient::new(config_for(&server)).expect("client");
        let auth = client
            .start_device_authorization()
            .await
            .expect("device authorization payload");

        assert_eq!(auth.user_code, "ABCD-EFGH");
        assert_eq!(auth.interval, StdDuration::from_secs(7));
        assert!(auth.expires_at > Utc::now());
        device_mock.assert_async().await;
    }

    #[tokio::test]
    async fn poll_device_token_reports_pending() {
        let server = MockServer::start_async().await;
        let pending = server
            .mock_async(|when, then| {
                when.method(POST)
                    .path("/token")
                    .body_contains(
                        "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code",
                    )
                    .body_contains("device_code=device-123");
                then.status(400)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "error": "authorization_pending"
                    }));
            })
            .await;

        let config = config_for(&server);
        let client = OAuthClient::new(config).expect("client");
        let auth = sample_device_authorization();

        let result = client.poll_device_token(&auth).await.expect("poll result");
        match result {
            DevicePollState::Pending { retry_in } => {
                assert_eq!(retry_in, StdDuration::from_secs(5));
            }
            other => panic!("expected pending state, got {other:?}"),
        }

        pending.assert_async().await;
    }

    #[tokio::test]
    async fn poll_device_token_applies_slow_down_backoff() {
        let server = MockServer::start_async().await;
        let slow = server
            .mock_async(|when, then| {
                when.method(POST).path("/token");
                then.status(400)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "error": "slow_down"
                    }));
            })
            .await;

        let config = config_for(&server);
        let client = OAuthClient::new(config).expect("client");
        let auth = sample_device_authorization();

        let result = client.poll_device_token(&auth).await.expect("poll result");
        match result {
            DevicePollState::Pending { retry_in } => {
                assert_eq!(retry_in, StdDuration::from_secs(10));
            }
            other => panic!("expected pending state, got {other:?}"),
        }

        slow.assert_async().await;
    }

    #[tokio::test]
    async fn poll_device_token_returns_token_when_authorized() {
        let server = MockServer::start_async().await;
        let token = server
            .mock_async(|when, then| {
                when.method(POST).path("/token");
                then.status(200)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "access_token": "token-abc",
                        "refresh_token": "refresh-xyz",
                        "expires_in": 3600,
                        "token_type": "Bearer",
                        "scope": "repo user"
                    }));
            })
            .await;

        let config = config_for(&server);
        let client = OAuthClient::new(config).expect("client");
        let auth = sample_device_authorization();

        let result = client.poll_device_token(&auth).await.expect("poll result");
        let token_info = match result {
            DevicePollState::Complete(token) => token,
            other => panic!("expected completion, got {other:?}"),
        };

        assert_eq!(token_info.access_token, "token-abc");
        assert_eq!(token_info.refresh_token.as_deref(), Some("refresh-xyz"));
        assert!(token_info.expires_at.is_some());
        token.assert_async().await;
    }

    #[tokio::test]
    async fn refresh_token_roundtrip() {
        let server = MockServer::start_async().await;
        let refresh = server
            .mock_async(|when, then| {
                when.method(POST)
                    .path("/token")
                    .body_contains("grant_type=refresh_token")
                    .body_contains("refresh_token=old-refresh");
                then.status(200)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "access_token": "token-new",
                        "refresh_token": "refresh-new",
                        "expires_in": 1200,
                        "token_type": "Bearer"
                    }));
            })
            .await;

        let config = config_for(&server);
        let client = OAuthClient::new(config).expect("client");
        let token = client
            .refresh_token("old-refresh")
            .await
            .expect("refresh response");

        assert_eq!(token.access_token, "token-new");
        assert_eq!(token.refresh_token.as_deref(), Some("refresh-new"));
        assert!(token.expires_at.is_some());
        refresh.assert_async().await;
    }
}
@@ -1,514 +0,0 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};

use futures::stream::{FuturesUnordered, StreamExt};
use log::{debug, warn};
use serde_json::Value;
use tokio::sync::RwLock;

use crate::config::Config;
use crate::{Error, Result};

use super::{
    GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderStatus, ProviderType,
};

/// Model information annotated with the originating provider metadata.
#[derive(Debug, Clone)]
pub struct AnnotatedModelInfo {
    pub provider_id: String,
    pub provider_status: ProviderStatus,
    pub model: ModelInfo,
}

/// Coordinates multiple [`ModelProvider`] implementations and tracks their
/// health state.
pub struct ProviderManager {
    providers: RwLock<HashMap<String, Arc<dyn ModelProvider>>>,
    status_cache: RwLock<HashMap<String, ProviderStatus>>,
    last_health_check: RwLock<Option<Instant>>,
    health_cache_ttl: Duration,
}

impl ProviderManager {
    /// Construct a new manager using the supplied configuration. Providers
    /// defined in the configuration start with a `RequiresSetup` status so
    /// that frontends can surface incomplete configuration to users.
    pub fn new(config: &Config) -> Self {
        let mut status_cache = HashMap::new();
        for provider_id in config.providers.keys() {
            status_cache.insert(provider_id.clone(), ProviderStatus::RequiresSetup);
        }

        // Use configured TTL (default 30 seconds) to reduce health check load
        let health_cache_ttl = config.general.health_check_ttl();

        Self {
            providers: RwLock::new(HashMap::new()),
            status_cache: RwLock::new(status_cache),
            last_health_check: RwLock::new(None),
            health_cache_ttl,
        }
    }

    /// Register a provider instance with the manager.
    pub async fn register_provider(&self, provider: Arc<dyn ModelProvider>) {
        let provider_id = provider.metadata().id.clone();
        debug!("registering provider {}", provider_id);

        self.providers
            .write()
            .await
            .insert(provider_id.clone(), provider);
        self.status_cache
            .write()
            .await
            .insert(provider_id, ProviderStatus::Unavailable);
    }

    /// Return a stream by routing the request to the designated provider.
    pub async fn generate(
        &self,
        provider_id: &str,
        request: GenerateRequest,
    ) -> Result<GenerateStream> {
        let provider = {
            let guard = self.providers.read().await;
            guard.get(provider_id).cloned()
        }
        .ok_or_else(|| Error::Config(format!("provider '{provider_id}' not registered")))?;

        match provider.generate_stream(request).await {
            Ok(stream) => {
                self.status_cache
                    .write()
                    .await
                    .insert(provider_id.to_string(), ProviderStatus::Available);
                Ok(stream)
            }
            Err(err) => {
                self.status_cache
                    .write()
                    .await
                    .insert(provider_id.to_string(), ProviderStatus::Unavailable);
                Err(err)
            }
        }
    }

    /// List models across all providers, updating provider status along the way.
    pub async fn list_all_models(&self) -> Result<Vec<AnnotatedModelInfo>> {
        let providers: Vec<(String, Arc<dyn ModelProvider>)> = {
            let guard = self.providers.read().await;
            guard
                .iter()
                .map(|(id, provider)| (id.clone(), Arc::clone(provider)))
                .collect()
        };

        let mut tasks = FuturesUnordered::new();

        for (provider_id, provider) in providers {
            tasks.push(async move {
                let log_id = provider_id.clone();
                let mut status = ProviderStatus::Unavailable;
                let mut models = Vec::new();

                match provider.health_check().await {
                    Ok(health) => {
                        status = health;
                        if matches!(status, ProviderStatus::Available) {
                            match provider.list_models().await {
                                Ok(list) => {
                                    models = list;
                                }
                                Err(err) => {
                                    status = ProviderStatus::Unavailable;
                                    warn!("listing models failed for provider {}: {}", log_id, err);
                                }
                            }
                        }
                    }
                    Err(err) => {
                        warn!("health check failed for provider {}: {}", log_id, err);
                    }
                }

                (provider_id, status, models)
            });
        }

        let mut annotated = Vec::new();
        let mut status_updates = HashMap::new();

        while let Some((provider_id, status, models)) = tasks.next().await {
            status_updates.insert(provider_id.clone(), status);
            for model in models {
                annotated.push(AnnotatedModelInfo {
                    provider_id: provider_id.clone(),
                    provider_status: status,
                    model,
                });
            }
        }

        {
            let mut guard = self.status_cache.write().await;
            for (provider_id, status) in status_updates {
                guard.insert(provider_id, status);
            }
        }

        enrich_model_metadata(&mut annotated);
        Ok(annotated)
    }

    /// Refresh the health of all registered providers in parallel, returning
    /// the latest status snapshot. Results are cached for the configured TTL
    /// to reduce provider load.
    pub async fn refresh_health(&self) -> HashMap<String, ProviderStatus> {
        // Check if cache is still fresh
        {
            let last_check = self.last_health_check.read().await;
            if let Some(instant) = *last_check
                && instant.elapsed() < self.health_cache_ttl
            {
                // Return cached status without performing checks
                debug!("returning cached health status (TTL not expired)");
                return self.status_cache.read().await.clone();
            }
        }

        // Cache expired or first check - perform actual health checks
        debug!("cache expired, performing health checks");
        let providers: Vec<(String, Arc<dyn ModelProvider>)> = {
            let guard = self.providers.read().await;
            guard
                .iter()
                .map(|(id, provider)| (id.clone(), Arc::clone(provider)))
                .collect()
        };

        let mut tasks = FuturesUnordered::new();
        for (provider_id, provider) in providers {
            tasks.push(async move {
                let status = match provider.health_check().await {
                    Ok(status) => status,
                    Err(err) => {
                        warn!("health check failed for provider {}: {}", provider_id, err);
                        ProviderStatus::Unavailable
                    }
                };
                (provider_id, status)
            });
        }

        let mut updates = HashMap::new();
        while let Some((provider_id, status)) = tasks.next().await {
            updates.insert(provider_id, status);
        }

        {
            let mut guard = self.status_cache.write().await;
            for (provider_id, status) in &updates {
                guard.insert(provider_id.clone(), *status);
            }
        }

        // Update cache timestamp
        *self.last_health_check.write().await = Some(Instant::now());

        updates
    }

    /// Force a health check refresh, bypassing the cache. This is useful
    /// when an immediate status update is required.
    pub async fn force_refresh_health(&self) -> HashMap<String, ProviderStatus> {
        debug!("forcing health check refresh (bypassing cache)");
        *self.last_health_check.write().await = None;
        self.refresh_health().await
    }

    /// Return the provider instance for an identifier.
    pub async fn get_provider(&self, provider_id: &str) -> Option<Arc<dyn ModelProvider>> {
        let guard = self.providers.read().await;
        guard.get(provider_id).cloned()
    }

    /// List the registered provider identifiers.
    pub async fn provider_ids(&self) -> Vec<String> {
        let guard = self.providers.read().await;
        guard.keys().cloned().collect()
    }

    /// Retrieve the last known status for a provider.
    pub async fn provider_status(&self, provider_id: &str) -> Option<ProviderStatus> {
        let guard = self.status_cache.read().await;
        guard.get(provider_id).copied()
    }

    /// Snapshot the currently cached statuses.
    pub async fn provider_statuses(&self) -> HashMap<String, ProviderStatus> {
        let guard = self.status_cache.read().await;
        guard.clone()
    }
}
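
// Hedged sketch, not part of the original file: typical startup wiring.
// `startup` and its arguments are assumptions; only the manager calls come
// from this module.
async fn startup(config: &Config, provider: Arc<dyn ModelProvider>) {
    let manager = ProviderManager::new(config);
    manager.register_provider(provider).await;
    // The first call performs real health checks; repeats within the TTL hit the cache.
    for (id, status) in manager.refresh_health().await {
        debug!("provider {id}: {status:?}");
    }
}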

fn enrich_model_metadata(models: &mut [AnnotatedModelInfo]) {
    let mut name_counts: HashMap<String, usize> = HashMap::new();
    for info in models.iter() {
        *name_counts.entry(info.model.name.clone()).or_default() += 1;
    }

    for info in models.iter_mut() {
        let provider_tag = provider_tag_for(&info.provider_id);
        info.model
            .metadata
            .insert("provider_tag".into(), Value::String(provider_tag.clone()));

        let scope_label = provider_scope_label(info.model.provider.provider_type);
        info.model.metadata.insert(
            "provider_scope".into(),
            Value::String(scope_label.to_string()),
        );
        info.model.metadata.insert(
            "provider_display_name".into(),
            Value::String(info.model.provider.name.clone()),
        );

        let display_name = if name_counts
            .get(&info.model.name)
            .is_some_and(|count| *count > 1)
        {
            let suffix = scope_label;
            let base = info.model.name.trim();
            if base.ends_with(&format!("· {}", suffix)) {
                base.to_string()
            } else {
                format!("{base} · {suffix}")
            }
        } else {
            info.model.name.clone()
        };

        info.model
            .metadata
            .insert("display_name".into(), Value::String(display_name));
    }
}

fn provider_tag_for(provider_id: &str) -> String {
    let normalized = provider_id.trim().to_ascii_lowercase().replace('-', "_");
    match normalized.as_str() {
        "ollama" | "ollama_local" => "ollama".to_string(),
        "ollama_cloud" => "ollama-cloud".to_string(),
        other => other.replace('_', "-"),
    }
}

fn provider_scope_label(provider_type: ProviderType) -> &'static str {
    match provider_type {
        ProviderType::Local => "local",
        ProviderType::Cloud => "cloud",
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;
    use std::sync::Arc;

    use crate::{Error, provider::ProviderMetadata};

    #[derive(Clone)]
    struct StaticProvider {
        metadata: ProviderMetadata,
        models: Vec<ModelInfo>,
        status: ProviderStatus,
    }

    impl StaticProvider {
        fn new(
            id: &str,
            name: &str,
            provider_type: ProviderType,
            status: ProviderStatus,
            models: Vec<ModelInfo>,
        ) -> Self {
            let metadata = ProviderMetadata::new(id, name, provider_type, false);
            let mut models = models;
            for model in &mut models {
                model.provider = metadata.clone();
            }
            let mut metadata = metadata;
            metadata
                .metadata
                .insert("test".into(), Value::String("true".into()));
            Self {
                metadata,
                models,
                status,
            }
        }
    }

    #[async_trait]
    impl ModelProvider for StaticProvider {
        fn metadata(&self) -> &ProviderMetadata {
            &self.metadata
        }

        async fn health_check(&self) -> Result<ProviderStatus> {
            Ok(self.status)
        }

        async fn list_models(&self) -> Result<Vec<ModelInfo>> {
            Ok(self.models.clone())
        }

        async fn generate_stream(&self, _request: GenerateRequest) -> Result<GenerateStream> {
            Err(Error::NotImplemented(
                "streaming not implemented in StaticProvider".to_string(),
            ))
        }
    }

    fn model(name: &str) -> ModelInfo {
        ModelInfo {
            name: name.to_string(),
            size_bytes: None,
            capabilities: Vec::new(),
            description: None,
            provider: ProviderMetadata::new("unused", "Unused", ProviderType::Local, false),
            metadata: HashMap::new(),
        }
    }

    #[tokio::test]
    async fn aggregates_local_provider_models() {
        let manager = ProviderManager::default();
        let provider = StaticProvider::new(
            "ollama_local",
            "Ollama Local",
            ProviderType::Local,
            ProviderStatus::Available,
            vec![model("qwen3:8b")],
        );
        manager.register_provider(Arc::new(provider)).await;

        let models = manager.list_all_models().await.unwrap();
        assert_eq!(models.len(), 1);
        let entry = &models[0];
        assert_eq!(entry.provider_id, "ollama_local");
        assert_eq!(entry.provider_status, ProviderStatus::Available);
        assert_eq!(
            entry
                .model
                .metadata
                .get("provider_tag")
                .and_then(Value::as_str),
            Some("ollama")
        );
        assert_eq!(
            entry
                .model
                .metadata
                .get("display_name")
                .and_then(Value::as_str),
            Some("qwen3:8b")
        );
    }

    #[tokio::test]
    async fn aggregates_cloud_provider_models() {
        let manager = ProviderManager::default();
        let provider = StaticProvider::new(
            "ollama_cloud",
            "Ollama Cloud",
            ProviderType::Cloud,
            ProviderStatus::Available,
            vec![model("qwen3:0.5b-cloud")],
        );
        manager.register_provider(Arc::new(provider)).await;

        let models = manager.list_all_models().await.unwrap();
        assert_eq!(models.len(), 1);
        let entry = &models[0];
        assert_eq!(
            entry
                .model
                .metadata
                .get("provider_tag")
                .and_then(Value::as_str),
            Some("ollama-cloud")
        );
        assert_eq!(
            entry
                .model
                .metadata
                .get("display_name")
                .and_then(Value::as_str),
            Some("qwen3:0.5b-cloud")
        );
    }

    #[tokio::test]
    async fn deduplicates_model_names_with_provider_suffix() {
        let manager = ProviderManager::default();
        let local = StaticProvider::new(
            "ollama_local",
            "Ollama Local",
            ProviderType::Local,
            ProviderStatus::Available,
            vec![model("qwen3:8b")],
        );
        let cloud = StaticProvider::new(
            "ollama_cloud",
            "Ollama Cloud",
            ProviderType::Cloud,
            ProviderStatus::Available,
            vec![model("qwen3:8b")],
        );
        manager.register_provider(Arc::new(local)).await;
        manager.register_provider(Arc::new(cloud)).await;

        let models = manager.list_all_models().await.unwrap();

        let local_entry = models
            .iter()
            .find(|entry| entry.provider_id == "ollama_local")
            .expect("local provider entry");
        let cloud_entry = models
            .iter()
            .find(|entry| entry.provider_id == "ollama_cloud")
            .expect("cloud provider entry");

        assert_eq!(
            local_entry
                .model
                .metadata
                .get("display_name")
                .and_then(Value::as_str),
            Some("qwen3:8b · local")
        );
        assert_eq!(
            cloud_entry
                .model
                .metadata
                .get("display_name")
                .and_then(Value::as_str),
            Some("qwen3:8b · cloud")
        );
    }
}

impl Default for ProviderManager {
    fn default() -> Self {
        Self {
            providers: RwLock::new(HashMap::new()),
            status_cache: RwLock::new(HashMap::new()),
            last_health_check: RwLock::new(None),
            health_cache_ttl: Duration::from_secs(30),
        }
    }
}
@@ -1,36 +0,0 @@
//! Unified provider abstraction layer.
//!
//! This module defines the async [`ModelProvider`] trait that all model
//! backends implement, together with a small suite of shared data structures
//! used for model discovery and streaming generation. The [`ProviderManager`]
//! orchestrates multiple providers and coordinates their health state.

mod manager;
mod types;

use std::pin::Pin;

use async_trait::async_trait;
use futures::Stream;

pub use self::{manager::*, types::*};

use crate::Result;

/// Convenience alias for the stream type yielded by [`ModelProvider::generate_stream`].
pub type GenerateStream = Pin<Box<dyn Stream<Item = Result<GenerateChunk>> + Send + 'static>>;

#[async_trait]
pub trait ModelProvider: Send + Sync {
    /// Returns descriptive metadata about the provider.
    fn metadata(&self) -> &ProviderMetadata;

    /// Check the current health state for the provider.
    async fn health_check(&self) -> Result<ProviderStatus>;

    /// List all models available through the provider.
    async fn list_models(&self) -> Result<Vec<ModelInfo>>;

    /// Acquire a streaming response for a generation request.
    async fn generate_stream(&self, request: GenerateRequest) -> Result<GenerateStream>;
}
@@ -1,205 +0,0 @@
//! Shared types used by the unified provider abstraction layer.

use std::{collections::HashMap, fmt};

use serde::{Deserialize, Serialize};
use serde_json::Value;

/// Categorises providers so the UI can distinguish between local and hosted
/// backends.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ProviderType {
    Local,
    Cloud,
}

/// Represents the current availability state for a provider.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ProviderStatus {
    Available,
    Unavailable,
    RequiresSetup,
}

/// High-level categories for provider failures.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum ProviderErrorKind {
    Unauthorized,
    RateLimited,
    Unavailable,
    Timeout,
    InvalidRequest,
    ModelNotFound,
    Network,
    Protocol,
    Unknown,
}

impl fmt::Display for ProviderErrorKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            ProviderErrorKind::Unauthorized => "unauthorized",
            ProviderErrorKind::RateLimited => "rate limited",
            ProviderErrorKind::Unavailable => "unavailable",
            ProviderErrorKind::Timeout => "timed out",
            ProviderErrorKind::InvalidRequest => "invalid request",
            ProviderErrorKind::ModelNotFound => "model not found",
            ProviderErrorKind::Network => "network error",
            ProviderErrorKind::Protocol => "protocol error",
            ProviderErrorKind::Unknown => "unknown failure",
        };
        write!(f, "{label}")
    }
}

/// Structured provider failure description used for UI and logs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderError {
    pub provider_id: Option<String>,
    pub kind: ProviderErrorKind,
    pub message: String,
    #[serde(default)]
    pub detail: Option<String>,
}

impl ProviderError {
    /// Construct a new provider error with the given category and message.
    pub fn new(kind: ProviderErrorKind, message: impl Into<String>) -> Self {
        Self {
            provider_id: None,
            kind,
            message: message.into(),
            detail: None,
        }
    }

    /// Attach the provider identifier to the failure.
    pub fn with_provider(mut self, provider_id: impl Into<String>) -> Self {
        self.provider_id = Some(provider_id.into());
        self
    }

    /// Attach a detailed description to the failure.
    pub fn with_detail(mut self, detail: impl Into<String>) -> Self {
        let text = detail.into();
        if !text.trim().is_empty() {
            self.detail = Some(text);
        }
        self
    }
}
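
// Hedged sketch, not part of the original file: the builder-style helpers
// compose a failure for display. The provider id and messages are examples.
fn provider_error_demo() {
    let err = ProviderError::new(ProviderErrorKind::RateLimited, "too many requests")
        .with_provider("ollama_cloud")
        .with_detail("retry after 30s");
    assert_eq!(
        err.to_string(),
        "ollama_cloud: too many requests (retry after 30s)"
    );
}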
impl fmt::Display for ProviderError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match (&self.detail, &self.provider_id) {
|
||||
(Some(detail), Some(provider)) => {
|
||||
write!(f, "{provider}: {} ({detail})", self.message)
|
||||
}
|
||||
(Some(detail), None) => write!(f, "{} ({detail})", self.message),
|
||||
(None, Some(provider)) => write!(f, "{provider}: {}", self.message),
|
||||
(None, None) => write!(f, "{}", self.message),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Describes core metadata for a provider implementation.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||
pub struct ProviderMetadata {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub provider_type: ProviderType,
|
||||
pub requires_auth: bool,
|
||||
#[serde(default)]
|
||||
pub metadata: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
impl ProviderMetadata {
|
||||
/// Construct a new metadata instance for a provider.
|
||||
pub fn new(
|
||||
id: impl Into<String>,
|
||||
name: impl Into<String>,
|
||||
provider_type: ProviderType,
|
||||
requires_auth: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: id.into(),
|
||||
name: name.into(),
|
||||
provider_type,
|
||||
requires_auth,
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Information about a model that can be displayed to users.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct ModelInfo {
|
||||
pub name: String,
|
||||
#[serde(default)]
|
||||
pub size_bytes: Option<u64>,
|
||||
#[serde(default)]
|
||||
pub capabilities: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
pub provider: ProviderMetadata,
|
||||
#[serde(default)]
|
||||
pub metadata: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
/// Unified request for streaming text generation across providers.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct GenerateRequest {
|
||||
pub model: String,
|
||||
#[serde(default)]
|
||||
pub prompt: Option<String>,
|
||||
#[serde(default)]
|
||||
pub context: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub parameters: HashMap<String, Value>,
|
||||
#[serde(default)]
|
||||
pub metadata: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
impl GenerateRequest {
|
||||
/// Helper for building a request from the minimum required fields.
|
||||
pub fn new(model: impl Into<String>) -> Self {
|
||||
Self {
|
||||
model: model.into(),
|
||||
prompt: None,
|
||||
context: Vec::new(),
|
||||
parameters: HashMap::new(),
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Streamed chunk of generation output from a model.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct GenerateChunk {
|
||||
#[serde(default)]
|
||||
pub text: Option<String>,
|
||||
#[serde(default)]
|
||||
pub is_final: bool,
|
||||
#[serde(default)]
|
||||
pub metadata: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
impl GenerateChunk {
|
||||
/// Construct a new chunk with the provided text payload.
|
||||
pub fn from_text(text: impl Into<String>) -> Self {
|
||||
Self {
|
||||
text: Some(text.into()),
|
||||
is_final: false,
|
||||
metadata: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
    /// Construct the terminal chunk that marks the end of a stream.
    pub fn final_chunk() -> Self {
        Self {
            text: None,
            is_final: true,
            metadata: HashMap::new(),
        }
    }
}
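For reference, a small sketch of how these types compose; the model id is a placeholder and the chunk slice stands in for a provider stream:

// Illustrative only: build a request and fold streamed chunks into one string.
fn collect_text(chunks: &[GenerateChunk]) -> String {
    let mut text = String::new();
    for chunk in chunks {
        if let Some(piece) = &chunk.text {
            text.push_str(piece);
        }
        if chunk.is_final {
            break;
        }
    }
    text
}

fn build_request() -> GenerateRequest {
    // "llama3:8b" is a placeholder model id, not a pinned default.
    let mut request = GenerateRequest::new("llama3:8b");
    request.prompt = Some("Summarize the latest release notes".to_string());
    request
}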
@@ -1,8 +0,0 @@
//! Built-in LLM provider implementations.
//!
//! Each provider integration lives in its own module so that maintenance
//! stays focused and configuration remains clear.

pub mod ollama;

pub use ollama::OllamaProvider;
File diff suppressed because it is too large
@@ -1,157 +0,0 @@
//! Router for managing multiple providers and routing requests

use crate::{Result, llm::*, types::*};
use anyhow::anyhow;
use std::sync::Arc;

/// A router that can distribute requests across multiple providers
pub struct Router {
    registry: ProviderRegistry,
    routing_rules: Vec<RoutingRule>,
    default_provider: Option<String>,
}

/// A rule for routing requests to specific providers
#[derive(Debug, Clone)]
pub struct RoutingRule {
    /// Pattern to match against model names
    pub model_pattern: String,
    /// Provider to route to
    pub provider: String,
    /// Priority (higher numbers are checked first)
    pub priority: u32,
}

impl Router {
    /// Create a new router
    pub fn new() -> Self {
        Self {
            registry: ProviderRegistry::new(),
            routing_rules: Vec::new(),
            default_provider: None,
        }
    }

    /// Register a provider with the router
    pub fn register_provider<P: LlmProvider + 'static>(&mut self, provider: P) {
        self.registry.register(provider);
    }

    /// Set the default provider
    pub fn set_default_provider(&mut self, provider_name: String) {
        self.default_provider = Some(provider_name);
    }

    /// Add a routing rule
    pub fn add_routing_rule(&mut self, rule: RoutingRule) {
        self.routing_rules.push(rule);
        // Sort by priority (descending)
        self.routing_rules
            .sort_by(|a, b| b.priority.cmp(&a.priority));
    }

    /// Route a request to the appropriate provider
    pub async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
        let provider = self.find_provider_for_model(&request.model)?;
        provider.send_prompt(request).await
    }

    /// Route a streaming request to the appropriate provider
    pub async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream> {
        let provider = self.find_provider_for_model(&request.model)?;
        provider.stream_prompt(request).await
    }

    /// List all available models from all providers
    pub async fn list_models(&self) -> Result<Vec<ModelInfo>> {
        self.registry.list_all_models().await
    }

    /// Find the appropriate provider for a given model
    fn find_provider_for_model(&self, model: &str) -> Result<Arc<dyn Provider>> {
        // Check routing rules first
        for rule in &self.routing_rules {
            if !self.matches_pattern(&rule.model_pattern, model) {
                continue;
            }
            if let Some(provider) = self.registry.get(&rule.provider) {
                return Ok(provider);
            }
        }

        // Fall back to default provider
        if let Some(provider) = self
            .default_provider
            .as_ref()
            .and_then(|default| self.registry.get(default))
        {
            return Ok(provider);
        }

        // If no default is configured, fall back to the first registered
        // provider so requests keep working when routing isn't set up.
        for provider_name in self.registry.list_providers() {
            if let Some(provider) = self.registry.get(&provider_name) {
                return Ok(provider);
            }
        }

        Err(crate::Error::Provider(anyhow!(
            "No provider found for model: {}",
            model
        )))
    }

    /// Check if a model name matches a pattern
    fn matches_pattern(&self, pattern: &str, model: &str) -> bool {
        // Simple pattern matching for now
        // Could be extended to support more complex patterns
        if pattern == "*" {
            return true;
        }

        if let Some(prefix) = pattern.strip_suffix('*') {
            return model.starts_with(prefix);
        }

        if let Some(suffix) = pattern.strip_prefix('*') {
            return model.ends_with(suffix);
        }

        pattern == model
    }

    /// Get routing configuration
    pub fn get_routing_rules(&self) -> &[RoutingRule] {
        &self.routing_rules
    }

    /// Get the default provider name
    pub fn get_default_provider(&self) -> Option<&str> {
        self.default_provider.as_deref()
    }
}

impl Default for Router {
    fn default() -> Self {
        Self::new()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pattern_matching() {
        let router = Router::new();

        assert!(router.matches_pattern("*", "any-model"));
        assert!(router.matches_pattern("gpt*", "gpt-4"));
        assert!(router.matches_pattern("gpt*", "gpt-3.5-turbo"));
        assert!(!router.matches_pattern("gpt*", "claude-3"));
        assert!(router.matches_pattern("*:latest", "llama2:latest"));
        assert!(router.matches_pattern("exact-match", "exact-match"));
        assert!(!router.matches_pattern("exact-match", "different-model"));
    }
}
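A configuration sketch against the API above; the "ollama" provider id refers to the bundled provider module, and the pattern uses the prefix/suffix matching implemented by matches_pattern:

// Route any ":latest"-tagged model to the Ollama provider, and use it as
// the default for everything else.
fn configure_router(router: &mut Router) {
    router.add_routing_rule(RoutingRule {
        model_pattern: "*:latest".to_string(),
        provider: "ollama".to_string(),
        priority: 10,
    });
    router.set_default_provider("ollama".to_string());
}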
@@ -1,216 +0,0 @@
use std::path::PathBuf;
use std::process::{Command, Stdio};
use std::time::{Duration, Instant};

use anyhow::{Context, Result, bail};
use tempfile::TempDir;

/// Configuration options for sandboxed process execution.
#[derive(Clone, Debug)]
pub struct SandboxConfig {
    pub allow_network: bool,
    pub allow_paths: Vec<PathBuf>,
    pub readonly_paths: Vec<PathBuf>,
    pub timeout_seconds: u64,
    pub max_memory_mb: u64,
}

impl Default for SandboxConfig {
    fn default() -> Self {
        Self {
            allow_network: false,
            allow_paths: Vec::new(),
            readonly_paths: Vec::new(),
            timeout_seconds: 30,
            max_memory_mb: 512,
        }
    }
}

/// Wrapper around a bubblewrap sandbox instance.
///
/// Memory limits are enforced via:
/// - bwrap's --rlimit-as (version >= 0.12.0)
/// - prlimit wrapper (fallback for older bwrap versions)
/// - timeout mechanism (always enforced as last resort)
pub struct SandboxedProcess {
    temp_dir: TempDir,
    config: SandboxConfig,
}

impl SandboxedProcess {
    pub fn new(config: SandboxConfig) -> Result<Self> {
        let temp_dir = TempDir::new().context("Failed to create temp directory")?;

        which::which("bwrap")
            .context("bubblewrap not found. Install with: sudo apt install bubblewrap")?;

        Ok(Self { temp_dir, config })
    }

    pub fn execute(&self, command: &str, args: &[&str]) -> Result<SandboxResult> {
        let supports_rlimit = self.supports_rlimit_as();
        let use_prlimit = !supports_rlimit && which::which("prlimit").is_ok();

        let mut cmd = if use_prlimit {
            // Use prlimit wrapper for older bwrap versions
            let mut prlimit_cmd = Command::new("prlimit");
            let memory_limit_bytes = self
                .config
                .max_memory_mb
                .saturating_mul(1024)
                .saturating_mul(1024);
            prlimit_cmd.arg(format!("--as={}", memory_limit_bytes));
            prlimit_cmd.arg("bwrap");
            prlimit_cmd
        } else {
            Command::new("bwrap")
        };

        cmd.args(["--unshare-all", "--die-with-parent", "--new-session"]);

        if self.config.allow_network {
            cmd.arg("--share-net");
        } else {
            cmd.arg("--unshare-net");
        }

        cmd.args(["--proc", "/proc", "--dev", "/dev", "--tmpfs", "/tmp"]);

        // Bind essential system paths readonly for executables and libraries
        let system_paths = ["/usr", "/bin", "/lib", "/lib64", "/etc"];
        for sys_path in &system_paths {
            let path = std::path::Path::new(sys_path);
            if path.exists() {
                cmd.arg("--ro-bind").arg(sys_path).arg(sys_path);
            }
        }

        // Bind /run for DNS resolution (resolv.conf may be a symlink to /run/systemd/resolve/*)
        if std::path::Path::new("/run").exists() {
            cmd.arg("--ro-bind").arg("/run").arg("/run");
        }

        for path in &self.config.allow_paths {
            let path_host = path.to_string_lossy().into_owned();
            let path_guest = path_host.clone();
            cmd.arg("--bind").arg(&path_host).arg(&path_guest);
        }

        for path in &self.config.readonly_paths {
            let path_host = path.to_string_lossy().into_owned();
            let path_guest = path_host.clone();
            cmd.arg("--ro-bind").arg(&path_host).arg(&path_guest);
        }

        let work_dir = self.temp_dir.path().to_string_lossy().into_owned();
        cmd.arg("--bind").arg(&work_dir).arg("/work");
        cmd.arg("--chdir").arg("/work");

        // Add memory limits via bwrap's --rlimit-as if supported (version >= 0.12.0)
        // If not supported, we use prlimit wrapper (set earlier)
        if supports_rlimit && !use_prlimit {
            let memory_limit_bytes = self
                .config
                .max_memory_mb
                .saturating_mul(1024)
                .saturating_mul(1024);
            let memory_soft = memory_limit_bytes.to_string();
            let memory_hard = memory_limit_bytes.to_string();
            cmd.arg("--rlimit-as").arg(&memory_soft).arg(&memory_hard);
        }

        cmd.arg(command);
        cmd.args(args);

        let start = Instant::now();
        let timeout = Duration::from_secs(self.config.timeout_seconds);

        // Spawn the process instead of waiting immediately
        let mut child = cmd
            .stdout(Stdio::piped())
            .stderr(Stdio::piped())
            .spawn()
            .context("Failed to spawn sandboxed command")?;

        let mut was_timeout = false;

        // Wait for the child with timeout
        let output = loop {
            match child.try_wait() {
                Ok(Some(_status)) => {
                    // Process exited
                    let output = child
                        .wait_with_output()
                        .context("Failed to collect process output")?;
                    break output;
                }
                Ok(None) => {
                    // Process still running, check timeout
                    if start.elapsed() >= timeout {
                        // Timeout exceeded, kill the process
                        was_timeout = true;
                        child.kill().context("Failed to kill timed-out process")?;
                        // Wait for the killed process to exit
                        let output = child
                            .wait_with_output()
                            .context("Failed to collect output from killed process")?;
                        break output;
                    }
                    // Sleep briefly before checking again
                    std::thread::sleep(Duration::from_millis(50));
                }
                Err(e) => {
                    bail!("Failed to check process status: {}", e);
                }
            }
        };

        let duration = start.elapsed();

        Ok(SandboxResult {
            stdout: String::from_utf8_lossy(&output.stdout).to_string(),
            stderr: String::from_utf8_lossy(&output.stderr).to_string(),
            exit_code: output.status.code().unwrap_or(-1),
            duration,
            was_timeout,
        })
    }

    /// Check if bubblewrap supports --rlimit-as option (version >= 0.12.0)
    fn supports_rlimit_as(&self) -> bool {
        // Try to get bwrap version
        let output = Command::new("bwrap").arg("--version").output();

        if let Ok(output) = output {
            let version_str = String::from_utf8_lossy(&output.stdout);
            // Parse version like "bubblewrap 0.11.0" or "0.11.0"
            return version_str
                .split_whitespace()
                .last()
                .and_then(|part| {
                    part.split_once('.').and_then(|(major, rest)| {
                        rest.split_once('.').and_then(|(minor, _)| {
                            let maj = major.parse::<u32>().ok()?;
                            let min = minor.parse::<u32>().ok()?;
                            Some((maj, min))
                        })
                    })
                })
                .map(|(maj, min)| maj > 0 || (maj == 0 && min >= 12))
                .unwrap_or(false);
        }

        // If we can't determine the version, assume it doesn't support it (safer default)
        false
    }
}

#[derive(Debug, Clone)]
pub struct SandboxResult {
    pub stdout: String,
    pub stderr: String,
    pub exit_code: i32,
    pub duration: Duration,
    pub was_timeout: bool,
}
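A usage sketch under the default limits; `echo` stands in for an arbitrary command and assumes a Linux host with bubblewrap installed:

// Run a trivial command in the sandbox (no network, 512 MB cap, 30 s timeout).
fn demo_sandbox() -> Result<()> {
    let sandbox = SandboxedProcess::new(SandboxConfig::default())?;
    let result = sandbox.execute("echo", &["hello"])?;
    assert_eq!(result.exit_code, 0);
    assert!(result.stdout.contains("hello"));
    assert!(!result.was_timeout);
    Ok(())
}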
File diff suppressed because it is too large
@@ -1,199 +0,0 @@
//! Shared application state types used across TUI frontends.

use std::fmt;

/// High-level application state reported by the UI loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AppState {
    Running,
    Quit,
}

/// Vim-style input modes supported by the TUI.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InputMode {
    Normal,
    Editing,
    ProviderSelection,
    ModelSelection,
    Help,
    Visual,
    Command,
    SessionBrowser,
    ThemeBrowser,
    RepoSearch,
    SymbolSearch,
}

impl fmt::Display for InputMode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            InputMode::Normal => "Normal",
            InputMode::Editing => "Editing",
            InputMode::ModelSelection => "Model",
            InputMode::ProviderSelection => "Provider",
            InputMode::Help => "Help",
            InputMode::Visual => "Visual",
            InputMode::Command => "Command",
            InputMode::SessionBrowser => "Sessions",
            InputMode::ThemeBrowser => "Themes",
            InputMode::RepoSearch => "Search",
            InputMode::SymbolSearch => "Symbols",
        };
        f.write_str(label)
    }
}

/// Represents which panel is currently focused in the TUI layout.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum FocusedPanel {
    Files,
    Chat,
    Thinking,
    Input,
    Code,
}

/// Auto-scroll state manager for scrollable panels.
#[derive(Debug, Clone)]
pub struct AutoScroll {
    pub scroll: usize,
    pub content_len: usize,
    pub stick_to_bottom: bool,
}

impl Default for AutoScroll {
    fn default() -> Self {
        Self {
            scroll: 0,
            content_len: 0,
            stick_to_bottom: true,
        }
    }
}

impl AutoScroll {
    /// Update scroll position based on viewport height.
    pub fn on_viewport(&mut self, viewport_h: usize) {
        let max = self.content_len.saturating_sub(viewport_h);
        if self.stick_to_bottom {
            self.scroll = max;
        } else {
            self.scroll = self.scroll.min(max);
        }
    }

    /// Handle user scroll input.
    pub fn on_user_scroll(&mut self, delta: isize, viewport_h: usize) {
        let max = self.content_len.saturating_sub(viewport_h) as isize;
        let s = (self.scroll as isize + delta).clamp(0, max) as usize;
        self.scroll = s;
        self.stick_to_bottom = s as isize == max;
    }

    pub fn scroll_half_page_down(&mut self, viewport_h: usize) {
        let delta = (viewport_h / 2) as isize;
        self.on_user_scroll(delta, viewport_h);
    }

    pub fn scroll_half_page_up(&mut self, viewport_h: usize) {
        let delta = -((viewport_h / 2) as isize);
        self.on_user_scroll(delta, viewport_h);
    }

    pub fn scroll_full_page_down(&mut self, viewport_h: usize) {
        let delta = viewport_h as isize;
        self.on_user_scroll(delta, viewport_h);
    }

    pub fn scroll_full_page_up(&mut self, viewport_h: usize) {
        let delta = -(viewport_h as isize);
        self.on_user_scroll(delta, viewport_h);
    }

    pub fn jump_to_top(&mut self) {
        self.scroll = 0;
        self.stick_to_bottom = false;
    }

    pub fn jump_to_bottom(&mut self, viewport_h: usize) {
        self.stick_to_bottom = true;
        self.on_viewport(viewport_h);
    }
}

/// Visual selection state for text selection.
#[derive(Debug, Clone, Default)]
pub struct VisualSelection {
    pub start: Option<(usize, usize)>,
    pub end: Option<(usize, usize)>,
}

impl VisualSelection {
    pub fn new() -> Self {
        Self::default()
    }

    pub fn start_at(&mut self, pos: (usize, usize)) {
        self.start = Some(pos);
        self.end = Some(pos);
    }

    pub fn extend_to(&mut self, pos: (usize, usize)) {
        self.end = Some(pos);
    }

    pub fn clear(&mut self) {
        self.start = None;
        self.end = None;
    }

    pub fn is_active(&self) -> bool {
        self.start.is_some() && self.end.is_some()
    }

    pub fn get_normalized(&self) -> Option<((usize, usize), (usize, usize))> {
        if let (Some(s), Some(e)) = (self.start, self.end) {
            if s.0 < e.0 || (s.0 == e.0 && s.1 <= e.1) {
                Some((s, e))
            } else {
                Some((e, s))
            }
        } else {
            None
        }
    }
}

/// Cursor position helper for navigating scrollable content.
#[derive(Debug, Clone, Copy, Default)]
pub struct CursorPosition {
    pub row: usize,
    pub col: usize,
}

impl CursorPosition {
    pub fn new(row: usize, col: usize) -> Self {
        Self { row, col }
    }

    pub fn move_up(&mut self, amount: usize) {
        self.row = self.row.saturating_sub(amount);
    }

    pub fn move_down(&mut self, amount: usize, max: usize) {
        self.row = (self.row + amount).min(max);
    }

    pub fn move_left(&mut self, amount: usize) {
        self.col = self.col.saturating_sub(amount);
    }

    pub fn move_right(&mut self, amount: usize, max: usize) {
        self.col = (self.col + amount).min(max);
    }

    pub fn as_tuple(&self) -> (usize, usize) {
        (self.row, self.col)
    }
}
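The scroll arithmetic above is easiest to see with concrete numbers; a small sketch:

// 100 lines of content in a 20-line viewport: max scroll offset is 80.
fn demo_autoscroll() {
    let mut scroll = AutoScroll::default();
    scroll.content_len = 100;
    scroll.on_viewport(20); // sticks to bottom: scroll == 80
    scroll.on_user_scroll(-10, 20); // user scrolls up: scroll == 70, unstuck
    scroll.jump_to_bottom(20); // re-sticks and snaps back to 80
    assert_eq!(scroll.scroll, 80);
}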
@@ -1,558 +0,0 @@
//! Session persistence and storage management backed by SQLite

// TODO: Upgrade to generic-array 1.x to remove deprecation warnings
#![allow(deprecated)]

use crate::types::Conversation;
use crate::{Error, Result};
use aes_gcm::aead::{Aead, KeyInit};
use aes_gcm::{Aes256Gcm, Nonce};
use ring::rand::{SecureRandom, SystemRandom};
use serde::{Deserialize, Serialize};
use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions, SqliteSynchronous};
use sqlx::{Pool, Row, Sqlite};
use std::fs;
use std::io::IsTerminal;
use std::io::{self, Write};
use std::path::{Path, PathBuf};
use std::str::FromStr;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use uuid::Uuid;

/// Metadata about a saved session
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionMeta {
    /// Conversation ID
    pub id: Uuid,
    /// Optional session name
    pub name: Option<String>,
    /// Optional AI-generated description
    pub description: Option<String>,
    /// Number of messages in the conversation
    pub message_count: usize,
    /// Model used
    pub model: String,
    /// When the session was created
    pub created_at: SystemTime,
    /// When the session was last updated
    pub updated_at: SystemTime,
}

/// Storage manager for persisting conversations in SQLite
pub struct StorageManager {
    pool: Pool<Sqlite>,
    database_path: PathBuf,
}

impl StorageManager {
    /// Create a new storage manager using the default database path
    pub async fn new() -> Result<Self> {
        let db_path = Self::default_database_path()?;
        Self::with_database_path(db_path).await
    }

    /// Create a storage manager using the provided database path
    pub async fn with_database_path(database_path: PathBuf) -> Result<Self> {
        if let Some(parent) = database_path.parent() && !parent.exists() {
            std::fs::create_dir_all(parent).map_err(|e| {
                Error::Storage(format!(
                    "Failed to create database directory {parent:?}: {e}"
                ))
            })?;
        }

        let options = SqliteConnectOptions::from_str(&format!(
            "sqlite://{}",
            database_path
                .to_str()
                .ok_or_else(|| Error::Storage("Invalid database path".to_string()))?
        ))
        .map_err(|e| Error::Storage(format!("Invalid database URL: {e}")))?
        .create_if_missing(true)
        .journal_mode(SqliteJournalMode::Wal)
        .synchronous(SqliteSynchronous::Normal);

        let pool = SqlitePoolOptions::new()
            .max_connections(5)
            .connect_with(options)
            .await
            .map_err(|e| Error::Storage(format!("Failed to connect to database: {e}")))?;

        sqlx::migrate!("./migrations")
            .run(&pool)
            .await
            .map_err(|e| Error::Storage(format!("Failed to run database migrations: {e}")))?;

        let storage = Self {
            pool,
            database_path,
        };

        storage.try_migrate_legacy_sessions().await?;

        Ok(storage)
    }

    /// Save a conversation. Existing entries are updated in-place.
    pub async fn save_conversation(
        &self,
        conversation: &Conversation,
        name: Option<String>,
    ) -> Result<()> {
        self.save_conversation_with_description(conversation, name, None)
            .await
    }

    /// Save a conversation with an optional description override
    pub async fn save_conversation_with_description(
        &self,
        conversation: &Conversation,
        name: Option<String>,
        description: Option<String>,
    ) -> Result<()> {
        let mut serialized = conversation.clone();
        if name.is_some() {
            serialized.name = name.clone();
        }
        if description.is_some() {
            serialized.description = description.clone();
        }

        let data = serde_json::to_string(&serialized)
            .map_err(|e| Error::Storage(format!("Failed to serialize conversation: {e}")))?;

        let created_at = to_epoch_seconds(serialized.created_at);
        let updated_at = to_epoch_seconds(serialized.updated_at);
        let message_count = serialized.messages.len() as i64;

        sqlx::query(
            r#"
            INSERT INTO conversations (
                id,
                name,
                description,
                model,
                message_count,
                created_at,
                updated_at,
                data
            ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
            ON CONFLICT(id) DO UPDATE SET
                name = excluded.name,
                description = excluded.description,
                model = excluded.model,
                message_count = excluded.message_count,
                created_at = excluded.created_at,
                updated_at = excluded.updated_at,
                data = excluded.data
            "#,
        )
        .bind(serialized.id.to_string())
        .bind(name.or(serialized.name.clone()))
        .bind(description.or(serialized.description.clone()))
        .bind(&serialized.model)
        .bind(message_count)
        .bind(created_at)
        .bind(updated_at)
        .bind(data)
        .execute(&self.pool)
        .await
        .map_err(|e| Error::Storage(format!("Failed to save conversation: {e}")))?;

        Ok(())
    }

    /// Load a conversation by ID
    pub async fn load_conversation(&self, id: Uuid) -> Result<Conversation> {
        let record = sqlx::query(r#"SELECT data FROM conversations WHERE id = ?1"#)
            .bind(id.to_string())
            .fetch_optional(&self.pool)
            .await
            .map_err(|e| Error::Storage(format!("Failed to load conversation: {e}")))?;

        let row =
            record.ok_or_else(|| Error::Storage(format!("No conversation found with id {id}")))?;

        let data: String = row
            .try_get("data")
            .map_err(|e| Error::Storage(format!("Failed to read conversation payload: {e}")))?;

        serde_json::from_str(&data)
            .map_err(|e| Error::Storage(format!("Failed to deserialize conversation: {e}")))
    }

    /// List metadata for all saved conversations ordered by most recent update
    pub async fn list_sessions(&self) -> Result<Vec<SessionMeta>> {
        let rows = sqlx::query(
            r#"
            SELECT id, name, description, model, message_count, created_at, updated_at
            FROM conversations
            ORDER BY updated_at DESC
            "#,
        )
        .fetch_all(&self.pool)
        .await
        .map_err(|e| Error::Storage(format!("Failed to list sessions: {e}")))?;

        let mut sessions = Vec::with_capacity(rows.len());
        for row in rows {
            let id_text: String = row
                .try_get("id")
                .map_err(|e| Error::Storage(format!("Failed to read id column: {e}")))?;
            let id = Uuid::parse_str(&id_text)
                .map_err(|e| Error::Storage(format!("Invalid UUID in storage: {e}")))?;

            let message_count: i64 = row
                .try_get("message_count")
                .map_err(|e| Error::Storage(format!("Failed to read message count: {e}")))?;

            let created_at: i64 = row
                .try_get("created_at")
                .map_err(|e| Error::Storage(format!("Failed to read created_at: {e}")))?;
            let updated_at: i64 = row
                .try_get("updated_at")
                .map_err(|e| Error::Storage(format!("Failed to read updated_at: {e}")))?;

            sessions.push(SessionMeta {
                id,
                name: row
                    .try_get("name")
                    .map_err(|e| Error::Storage(format!("Failed to read name: {e}")))?,
                description: row
                    .try_get("description")
                    .map_err(|e| Error::Storage(format!("Failed to read description: {e}")))?,
                model: row
                    .try_get("model")
                    .map_err(|e| Error::Storage(format!("Failed to read model: {e}")))?,
                message_count: message_count as usize,
                created_at: from_epoch_seconds(created_at),
                updated_at: from_epoch_seconds(updated_at),
            });
        }

        Ok(sessions)
    }

    /// Delete a conversation by ID
    pub async fn delete_session(&self, id: Uuid) -> Result<()> {
        sqlx::query("DELETE FROM conversations WHERE id = ?1")
            .bind(id.to_string())
            .execute(&self.pool)
            .await
            .map_err(|e| Error::Storage(format!("Failed to delete conversation: {e}")))?;
        Ok(())
    }

    pub async fn store_secure_item(
        &self,
        key: &str,
        plaintext: &[u8],
        master_key: &[u8],
    ) -> Result<()> {
        let cipher = create_cipher(master_key)?;
        let nonce_bytes = generate_nonce()?;
        let nonce = Nonce::from_slice(&nonce_bytes);
        let ciphertext = cipher
            .encrypt(nonce, plaintext)
            .map_err(|e| Error::Storage(format!("Failed to encrypt secure item: {e}")))?;

        let now = to_epoch_seconds(SystemTime::now());

        sqlx::query(
            r#"
            INSERT INTO secure_items (key, nonce, ciphertext, created_at, updated_at)
            VALUES (?1, ?2, ?3, ?4, ?5)
            ON CONFLICT(key) DO UPDATE SET
                nonce = excluded.nonce,
                ciphertext = excluded.ciphertext,
                updated_at = excluded.updated_at
            "#,
        )
        .bind(key)
        .bind(&nonce_bytes[..])
        .bind(&ciphertext[..])
        .bind(now)
        .bind(now)
        .execute(&self.pool)
        .await
        .map_err(|e| Error::Storage(format!("Failed to store secure item: {e}")))?;

        Ok(())
    }

    pub async fn load_secure_item(&self, key: &str, master_key: &[u8]) -> Result<Option<Vec<u8>>> {
        let record = sqlx::query("SELECT nonce, ciphertext FROM secure_items WHERE key = ?1")
            .bind(key)
            .fetch_optional(&self.pool)
            .await
            .map_err(|e| Error::Storage(format!("Failed to load secure item: {e}")))?;

        let Some(row) = record else {
            return Ok(None);
        };

        let nonce_bytes: Vec<u8> = row
            .try_get("nonce")
            .map_err(|e| Error::Storage(format!("Failed to read secure item nonce: {e}")))?;
        let ciphertext: Vec<u8> = row
            .try_get("ciphertext")
            .map_err(|e| Error::Storage(format!("Failed to read secure item ciphertext: {e}")))?;

        if nonce_bytes.len() != 12 {
            return Err(Error::Storage(
                "Invalid nonce length for secure item".to_string(),
            ));
        }

        let cipher = create_cipher(master_key)?;
        let nonce = Nonce::from_slice(&nonce_bytes);
        let plaintext = cipher
            .decrypt(nonce, ciphertext.as_ref())
            .map_err(|e| Error::Storage(format!("Failed to decrypt secure item: {e}")))?;

        Ok(Some(plaintext))
    }

    pub async fn delete_secure_item(&self, key: &str) -> Result<()> {
        sqlx::query("DELETE FROM secure_items WHERE key = ?1")
            .bind(key)
            .execute(&self.pool)
            .await
            .map_err(|e| Error::Storage(format!("Failed to delete secure item: {e}")))?;
        Ok(())
    }

    pub async fn clear_secure_items(&self) -> Result<()> {
        sqlx::query("DELETE FROM secure_items")
            .execute(&self.pool)
            .await
            .map_err(|e| Error::Storage(format!("Failed to clear secure items: {e}")))?;
        Ok(())
    }

    /// Database location used by this storage manager
    pub fn database_path(&self) -> &Path {
        &self.database_path
    }

    /// Determine default database path (platform specific)
    pub fn default_database_path() -> Result<PathBuf> {
        let data_dir = dirs::data_local_dir()
            .ok_or_else(|| Error::Storage("Could not determine data directory".to_string()))?;
        Ok(data_dir.join("owlen").join("owlen.db"))
    }

    fn legacy_sessions_dir() -> Result<PathBuf> {
        let data_dir = dirs::data_local_dir()
            .ok_or_else(|| Error::Storage("Could not determine data directory".to_string()))?;
        Ok(data_dir.join("owlen").join("sessions"))
    }

    async fn database_has_records(&self) -> Result<bool> {
        let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM conversations")
            .fetch_one(&self.pool)
            .await
            .map_err(|e| Error::Storage(format!("Failed to inspect database: {e}")))?;
        Ok(count > 0)
    }

    async fn try_migrate_legacy_sessions(&self) -> Result<()> {
        if self.database_has_records().await? {
            return Ok(());
        }

        let legacy_dir = match Self::legacy_sessions_dir() {
            Ok(dir) => dir,
            Err(_) => return Ok(()),
        };

        if !legacy_dir.exists() {
            return Ok(());
        }

        let entries = fs::read_dir(&legacy_dir).map_err(|e| {
            Error::Storage(format!("Failed to read legacy sessions directory: {e}"))
        })?;

        let mut json_files = Vec::new();
        for entry in entries.flatten() {
            let path = entry.path();
            if path.extension().and_then(|s| s.to_str()) == Some("json") {
                json_files.push(path);
            }
        }

        if json_files.is_empty() {
            return Ok(());
        }

        if !io::stdin().is_terminal() {
            return Ok(());
        }

        println!(
            "Legacy OWLEN session files were found in {}.",
            legacy_dir.display()
        );
        if !prompt_yes_no("Migrate them to the new SQLite storage? (y/N) ")? {
            println!("Skipping legacy session migration.");
            return Ok(());
        }

        println!("Migrating legacy sessions...");
        let mut migrated = 0usize;
        for path in &json_files {
            match fs::read_to_string(path) {
                Ok(content) => match serde_json::from_str::<Conversation>(&content) {
                    Ok(conversation) => {
                        if let Err(err) = self
                            .save_conversation_with_description(
                                &conversation,
                                conversation.name.clone(),
                                conversation.description.clone(),
                            )
                            .await
                        {
                            println!(" • Failed to migrate {}: {}", path.display(), err);
                        } else {
                            migrated += 1;
                        }
                    }
                    Err(err) => {
                        println!(
                            " • Failed to parse conversation {}: {}",
                            path.display(),
                            err
                        );
                    }
                },
                Err(err) => {
                    println!(" • Failed to read {}: {}", path.display(), err);
                }
            }
        }

        if migrated > 0 && let Err(err) = archive_legacy_directory(&legacy_dir) {
            println!(
                "Warning: migrated sessions but failed to archive legacy directory: {}",
                err
            );
        }

        println!("Migrated {} legacy sessions.", migrated);
        Ok(())
    }
}

fn to_epoch_seconds(time: SystemTime) -> i64 {
    match time.duration_since(UNIX_EPOCH) {
        Ok(duration) => duration.as_secs() as i64,
        Err(_) => 0,
    }
}

fn from_epoch_seconds(seconds: i64) -> SystemTime {
    UNIX_EPOCH + Duration::from_secs(seconds.max(0) as u64)
}

fn prompt_yes_no(prompt: &str) -> Result<bool> {
    print!("{}", prompt);
    io::stdout()
        .flush()
        .map_err(|e| Error::Storage(format!("Failed to flush stdout: {e}")))?;

    let mut input = String::new();
    io::stdin()
        .read_line(&mut input)
        .map_err(|e| Error::Storage(format!("Failed to read input: {e}")))?;
    let trimmed = input.trim().to_lowercase();
    Ok(matches!(trimmed.as_str(), "y" | "yes"))
}

fn archive_legacy_directory(legacy_dir: &Path) -> Result<()> {
    let mut backup_dir = legacy_dir.with_file_name("sessions_legacy_backup");
    let mut counter = 1;
    while backup_dir.exists() {
        backup_dir = legacy_dir.with_file_name(format!("sessions_legacy_backup_{}", counter));
        counter += 1;
    }

    fs::rename(legacy_dir, &backup_dir).map_err(|e| {
        Error::Storage(format!(
            "Failed to archive legacy sessions directory {}: {}",
            legacy_dir.display(),
            e
        ))
    })?;

    println!("Legacy session files archived to {}", backup_dir.display());
    Ok(())
}

fn create_cipher(master_key: &[u8]) -> Result<Aes256Gcm> {
    if master_key.len() != 32 {
        return Err(Error::Storage(
            "Master key must be 32 bytes for AES-256-GCM".to_string(),
        ));
    }
    Aes256Gcm::new_from_slice(master_key).map_err(|_| {
        Error::Storage("Failed to initialize cipher with provided master key".to_string())
    })
}

fn generate_nonce() -> Result<[u8; 12]> {
    let mut nonce = [0u8; 12];
    SystemRandom::new()
        .fill(&mut nonce)
        .map_err(|_| Error::Storage("Failed to generate nonce".to_string()))?;
    Ok(nonce)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{Conversation, Message};
    use tempfile::tempdir;

    fn sample_conversation() -> Conversation {
        Conversation {
            id: Uuid::new_v4(),
            name: Some("Test conversation".to_string()),
            description: Some("A sample conversation".to_string()),
            messages: vec![
                Message::user("Hello".to_string()),
                Message::assistant("Hi".to_string()),
            ],
            model: "test-model".to_string(),
            created_at: SystemTime::now(),
            updated_at: SystemTime::now(),
        }
    }

    #[tokio::test]
    async fn test_storage_lifecycle() {
        let temp_dir = tempdir().expect("failed to create temp dir");
        let db_path = temp_dir.path().join("owlen.db");
        let storage = StorageManager::with_database_path(db_path).await.unwrap();

        let conversation = sample_conversation();
        storage
            .save_conversation(&conversation, None)
            .await
            .expect("failed to save conversation");

        let sessions = storage.list_sessions().await.unwrap();
        assert_eq!(sessions.len(), 1);
        assert_eq!(sessions[0].id, conversation.id);

        let loaded = storage.load_conversation(conversation.id).await.unwrap();
        assert_eq!(loaded.messages.len(), 2);

        storage
            .delete_session(conversation.id)
            .await
            .expect("failed to delete conversation");
        let sessions = storage.list_sessions().await.unwrap();
        assert!(sessions.is_empty());
    }
}
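As a rough usage sketch against the secure-item API above (the key name is arbitrary and the all-zero master key is a placeholder; a real caller would derive the key, never hard-code it):

// Round-trip an encrypted secret through the secure_items table.
async fn demo_secure_items(storage: &StorageManager) -> Result<()> {
    let master_key = [0u8; 32]; // placeholder only; must be 32 bytes for AES-256-GCM
    storage
        .store_secure_item("api_token", b"s3cret", &master_key)
        .await?;
    let loaded = storage.load_secure_item("api_token", &master_key).await?;
    assert_eq!(loaded.as_deref(), Some(&b"s3cret"[..]));
    storage.delete_secure_item("api_token").await?;
    Ok(())
}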
@@ -1,151 +0,0 @@
//! Tool module aggregating built‑in tool implementations.
//!
//! The crate originally declared `pub mod tools;` in `lib.rs` but the source
//! directory only contained individual tool files without a `mod.rs`, causing the
//! compiler to look for `tools.rs` and fail. Adding this module file makes the
//! directory a proper Rust module and re‑exports the concrete tool types.

pub mod code_exec;
pub mod fs_tools;
pub mod registry;
pub mod web_scrape;
pub mod web_search;

use async_trait::async_trait;
use once_cell::sync::Lazy;
use regex::Regex;
use serde_json::{Value, json};
use std::collections::HashMap;
use std::time::Duration;

use crate::Result;

/// MCP requires tool identifiers to match `^[A-Za-z0-9_-]{1,64}$`.
pub const MAX_TOOL_IDENTIFIER_LEN: usize = 64;

static TOOL_IDENTIFIER_RE: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"^[A-Za-z0-9_-]{1,64}$").expect("valid tool identifier regex"));

pub const WEB_SEARCH_TOOL_NAME: &str = "web_search";

/// Return the canonical identifier for a tool.
pub fn canonical_tool_name(name: &str) -> &str {
    name
}

/// Check whether two tool identifiers refer to the same logical tool.
pub fn tool_name_matches(lhs: &str, rhs: &str) -> bool {
    canonical_tool_name(lhs) == canonical_tool_name(rhs)
}

/// Determine whether the provided identifier satisfies the MCP naming contract.
pub fn is_valid_tool_identifier(name: &str) -> bool {
    TOOL_IDENTIFIER_RE.is_match(name)
}

/// Provide lint-style feedback when a tool identifier falls outside the MCP rules.
pub fn tool_identifier_violation(name: &str) -> Option<String> {
    if name.is_empty() {
        return Some("Tool identifiers must not be empty.".to_string());
    }

    if name.len() > MAX_TOOL_IDENTIFIER_LEN {
        return Some(format!(
            "Tool identifier '{name}' exceeds the {MAX_TOOL_IDENTIFIER_LEN}-character MCP limit."
        ));
    }

    if name.trim() != name {
        return Some(format!(
            "Tool identifier '{name}' contains leading or trailing whitespace."
        ));
    }

    if !TOOL_IDENTIFIER_RE.is_match(name) {
        return Some(format!(
            "Tool identifier '{name}' may only contain ASCII letters, digits, hyphens, or underscores."
        ));
    }

    None
}

/// Trait representing a tool that can be called via the MCP interface.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Unique name of the tool (used in the MCP protocol).
    fn name(&self) -> &'static str;
    /// Human‑readable description for documentation.
    fn description(&self) -> &'static str;
    /// JSON‑Schema describing the expected arguments.
    fn schema(&self) -> Value;
    /// Whether the tool needs network access (defaults to no).
    fn requires_network(&self) -> bool {
        false
    }
    /// Filesystem capabilities the tool needs (defaults to none).
    fn requires_filesystem(&self) -> Vec<String> {
        Vec::new()
    }
    /// Optional additional identifiers (must remain spec-compliant).
    fn aliases(&self) -> &'static [&'static str] {
        &[]
    }
    /// Execute the tool with the provided arguments.
    async fn execute(&self, args: Value) -> Result<ToolResult>;
}

/// Result returned by a tool execution.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ToolResult {
    /// Indicates whether the tool completed successfully.
    pub success: bool,
    /// Human‑readable status string – retained for compatibility.
    pub status: String,
    /// Arbitrary JSON payload describing the tool output.
    pub output: Value,
    /// Execution duration.
    #[serde(skip_serializing_if = "Duration::is_zero", default)]
    pub duration: Duration,
    /// Optional key/value metadata for the tool invocation.
    #[serde(default)]
    pub metadata: HashMap<String, String>,
}

impl ToolResult {
    pub fn success(output: Value) -> Self {
        Self {
            success: true,
            status: "success".into(),
            output,
            duration: Duration::default(),
            metadata: HashMap::new(),
        }
    }

    pub fn error(msg: &str) -> Self {
        Self {
            success: false,
            status: "error".into(),
            output: json!({ "error": msg }),
            duration: Duration::default(),
            metadata: HashMap::new(),
        }
    }

    pub fn cancelled(msg: &str) -> Self {
        Self {
            success: false,
            status: "cancelled".into(),
            output: json!({ "error": msg }),
            duration: Duration::default(),
            metadata: HashMap::new(),
        }
    }
}

// Re‑export the most commonly used types so they can be accessed as
// `owlen_core::tools::CodeExecTool`, etc.
pub use code_exec::CodeExecTool;
pub use fs_tools::{ResourcesDeleteTool, ResourcesGetTool, ResourcesListTool, ResourcesWriteTool};
pub use registry::ToolRegistry;
pub use web_scrape::WebScrapeTool;
pub use web_search::{WebSearchSettings, WebSearchTool};
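To make the trait contract concrete, here is a minimal hypothetical implementation; it relies on the module's own imports, and the network, filesystem, and alias defaults are inherited unchanged:

// The smallest spec-compliant Tool: echoes its arguments back.
struct EchoTool;

#[async_trait]
impl Tool for EchoTool {
    fn name(&self) -> &'static str {
        "echo" // satisfies ^[A-Za-z0-9_-]{1,64}$
    }

    fn description(&self) -> &'static str {
        "Echoes its arguments back to the caller."
    }

    fn schema(&self) -> Value {
        json!({ "type": "object", "additionalProperties": true })
    }

    async fn execute(&self, args: Value) -> Result<ToolResult> {
        Ok(ToolResult::success(args))
    }
}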
@@ -1,148 +0,0 @@
use std::sync::Arc;
use std::time::Instant;

use crate::Result;
use anyhow::{Context, anyhow};
use async_trait::async_trait;
use serde_json::{Value, json};

use super::{Tool, ToolResult};
use crate::sandbox::{SandboxConfig, SandboxedProcess};

pub struct CodeExecTool {
    allowed_languages: Arc<Vec<String>>,
}

impl CodeExecTool {
    pub fn new(allowed_languages: Vec<String>) -> Self {
        Self {
            allowed_languages: Arc::new(allowed_languages),
        }
    }
}

#[async_trait]
impl Tool for CodeExecTool {
    fn name(&self) -> &'static str {
        "code_exec"
    }

    fn description(&self) -> &'static str {
        "Execute code snippets within a sandboxed environment"
    }

    fn schema(&self) -> Value {
        json!({
            "type": "object",
            "properties": {
                "language": {
                    "type": "string",
                    "enum": self.allowed_languages.as_slice(),
                    "description": "Language of the code block"
                },
                "code": {
                    "type": "string",
                    "minLength": 1,
                    "maxLength": 10000,
                    "description": "Code to execute"
                },
                "timeout": {
                    "type": "integer",
                    "minimum": 1,
                    "maximum": 300,
                    "default": 30,
                    "description": "Execution timeout in seconds"
                }
            },
            "required": ["language", "code"],
            "additionalProperties": false
        })
    }

    async fn execute(&self, args: Value) -> Result<ToolResult> {
        let start = Instant::now();

        let language = args
            .get("language")
            .and_then(Value::as_str)
            .context("Missing language parameter")?;
        let code = args
            .get("code")
            .and_then(Value::as_str)
            .context("Missing code parameter")?;
        let timeout = args.get("timeout").and_then(Value::as_u64).unwrap_or(30);

        if !self.allowed_languages.iter().any(|lang| lang == language) {
            return Err(anyhow!("Language '{}' not permitted", language).into());
        }

        let (command, command_args) = match language {
            "python" => (
                "python3".to_string(),
                vec!["-c".to_string(), code.to_string()],
            ),
            "javascript" => ("node".to_string(), vec!["-e".to_string(), code.to_string()]),
            "bash" => ("bash".to_string(), vec!["-c".to_string(), code.to_string()]),
            "rust" => {
                let mut result =
                    ToolResult::error("Rust execution is not yet supported in the sandbox");
                result.duration = start.elapsed();
                return Ok(result);
            }
            other => return Err(anyhow!("Unsupported language: {}", other).into()),
        };

        let sandbox_config = SandboxConfig {
            allow_network: false,
            timeout_seconds: timeout,
            ..Default::default()
        };

        let sandbox_result = tokio::task::spawn_blocking(move || {
            let sandbox = SandboxedProcess::new(sandbox_config)?;
            let arg_refs: Vec<&str> = command_args.iter().map(|s| s.as_str()).collect();
            sandbox.execute(&command, &arg_refs)
        })
        .await
        .context("Sandbox execution task failed")??;

        let mut result = if sandbox_result.exit_code == 0 {
            ToolResult::success(json!({
                "stdout": sandbox_result.stdout,
                "stderr": sandbox_result.stderr,
                "exit_code": sandbox_result.exit_code,
                "timed_out": sandbox_result.was_timeout,
            }))
        } else {
            let error_msg = if sandbox_result.was_timeout {
                format!(
                    "Execution timed out after {} seconds (exit code {}): {}",
                    timeout, sandbox_result.exit_code, sandbox_result.stderr
                )
            } else {
                format!(
                    "Execution failed with status {}: {}",
                    sandbox_result.exit_code, sandbox_result.stderr
                )
            };
            let mut err_result = ToolResult::error(&error_msg);
            err_result.output = json!({
                "stdout": sandbox_result.stdout,
                "stderr": sandbox_result.stderr,
                "exit_code": sandbox_result.exit_code,
                "timed_out": sandbox_result.was_timeout,
            });
            err_result
        };

        result.duration = start.elapsed();
        result
            .metadata
            .insert("language".to_string(), language.to_string());
        result
            .metadata
            .insert("timeout_seconds".to_string(), timeout.to_string());

        Ok(result)
    }
}
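A usage sketch, assuming bubblewrap is available on the host and python3 exists inside the sandbox's read-only system binds:

// Run a tiny Python snippet through the tool and inspect the JSON payload.
async fn demo_code_exec() -> Result<()> {
    let tool = CodeExecTool::new(vec!["python".to_string(), "bash".to_string()]);
    let result = tool
        .execute(json!({ "language": "python", "code": "print(2 + 2)" }))
        .await?;
    assert!(result.success);
    assert_eq!(result.output["stdout"], json!("4\n"));
    Ok(())
}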
@@ -1,198 +0,0 @@
use crate::tools::{Tool, ToolResult};
use crate::{Error, Result};
use async_trait::async_trait;
use path_clean::PathClean;
use serde::Deserialize;
use serde_json::json;
use std::env;
use std::fs;
use std::path::{Path, PathBuf};

#[derive(Deserialize)]
struct FileArgs {
    path: String,
}

fn sanitize_path(path: &str, root: &Path) -> Result<PathBuf> {
    let path = Path::new(path);
    let path = if path.is_absolute() {
        // Strip leading '/' to treat as relative to the project root.
        path.strip_prefix("/")
            .map_err(|_| Error::InvalidInput("Invalid path".into()))?
            .to_path_buf()
    } else {
        path.to_path_buf()
    };

    let full_path = root.join(path).clean();

    if !full_path.starts_with(root) {
        return Err(Error::PermissionDenied("Path traversal detected".into()));
    }

    Ok(full_path)
}

pub struct ResourcesListTool;

#[async_trait]
impl Tool for ResourcesListTool {
    fn name(&self) -> &'static str {
        "resources_list"
    }

    fn description(&self) -> &'static str {
        "Lists directory contents."
    }

    fn schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "path": {
                    "type": "string",
                    "description": "The path to the directory to list."
                }
            },
            "required": ["path"]
        })
    }

    async fn execute(&self, args: serde_json::Value) -> Result<ToolResult> {
        let args: FileArgs = serde_json::from_value(args)?;
        let root = env::current_dir()?;
        let full_path = sanitize_path(&args.path, &root)?;

        let entries = fs::read_dir(full_path)?;

        let mut result = Vec::new();
        for entry in entries {
            let entry = entry?;
            result.push(entry.file_name().to_string_lossy().to_string());
        }

        Ok(ToolResult::success(serde_json::to_value(result)?))
    }
}

pub struct ResourcesGetTool;

#[async_trait]
impl Tool for ResourcesGetTool {
    fn name(&self) -> &'static str {
        "resources_get"
    }

    fn description(&self) -> &'static str {
        "Reads file content."
    }

    fn schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "path": {
                    "type": "string",
                    "description": "The path to the file to read."
                }
            },
            "required": ["path"]
        })
    }

    async fn execute(&self, args: serde_json::Value) -> Result<ToolResult> {
        let args: FileArgs = serde_json::from_value(args)?;
        let root = env::current_dir()?;
        let full_path = sanitize_path(&args.path, &root)?;

        let content = fs::read_to_string(full_path)?;

        Ok(ToolResult::success(serde_json::to_value(content)?))
    }
}

// ---------------------------------------------------------------------------
// Write tool – writes (or overwrites) a file under the project root.
// ---------------------------------------------------------------------------
pub struct ResourcesWriteTool;

#[derive(Deserialize)]
struct WriteArgs {
    path: String,
    content: String,
}

#[async_trait]
impl Tool for ResourcesWriteTool {
    fn name(&self) -> &'static str {
        "resources_write"
    }
    fn description(&self) -> &'static str {
        "Writes (or overwrites) a file. Requires explicit consent."
    }
    fn schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "path": { "type": "string", "description": "Target file path (relative to project root)" },
                "content": { "type": "string", "description": "File content to write" }
            },
            "required": ["path", "content"]
        })
    }
    fn requires_filesystem(&self) -> Vec<String> {
        vec!["file_write".to_string()]
    }
    async fn execute(&self, args: serde_json::Value) -> Result<ToolResult> {
        let args: WriteArgs = serde_json::from_value(args)?;
        let root = env::current_dir()?;
        let full_path = sanitize_path(&args.path, &root)?;
        // Ensure the parent directory exists
        if let Some(parent) = full_path.parent() {
            fs::create_dir_all(parent)?;
        }
        fs::write(full_path, args.content)?;
        Ok(ToolResult::success(json!(null)))
    }
}

// ---------------------------------------------------------------------------
// Delete tool – deletes a file under the project root.
// ---------------------------------------------------------------------------
pub struct ResourcesDeleteTool;

#[derive(Deserialize)]
struct DeleteArgs {
    path: String,
}

#[async_trait]
impl Tool for ResourcesDeleteTool {
    fn name(&self) -> &'static str {
        "resources_delete"
    }
    fn description(&self) -> &'static str {
        "Deletes a file. Requires explicit consent."
    }
    fn schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": { "path": { "type": "string", "description": "File path to delete" } },
            "required": ["path"]
        })
    }
    fn requires_filesystem(&self) -> Vec<String> {
        vec!["file_delete".to_string()]
    }
    async fn execute(&self, args: serde_json::Value) -> Result<ToolResult> {
        let args: DeleteArgs = serde_json::from_value(args)?;
        let root = env::current_dir()?;
        let full_path = sanitize_path(&args.path, &root)?;
        if full_path.is_file() {
            fs::remove_file(full_path)?;
            Ok(ToolResult::success(json!(null)))
        } else {
            Err(Error::InvalidInput("Path does not refer to a file".into()))
        }
    }
}
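The sanitizer's three cases, spelled out as a sketch (usable only inside this module, since sanitize_path is private; the root path is illustrative):

// How sanitize_path treats relative, absolute, and traversal inputs.
fn demo_sanitize() -> Result<()> {
    let root = Path::new("/srv/project");
    // Relative paths resolve under the root.
    assert_eq!(sanitize_path("src/main.rs", root)?, root.join("src/main.rs"));
    // Absolute paths are re-rooted under the project instead of escaping it.
    assert_eq!(sanitize_path("/etc/passwd", root)?, root.join("etc/passwd"));
    // Traversal out of the root is rejected with PermissionDenied.
    assert!(sanitize_path("../outside", root).is_err());
    Ok(())
}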
@@ -1,206 +0,0 @@
use std::collections::HashMap;
use std::sync::Arc;

use crate::{Error, Result};
use anyhow::Context;
use serde_json::Value;

use super::{
    Tool, ToolResult, WEB_SEARCH_TOOL_NAME, canonical_tool_name, tool_identifier_violation,
};
use crate::config::Config;
use crate::mode::Mode;
use crate::ui::UiController;

pub struct ToolRegistry {
    tools: HashMap<String, Arc<dyn Tool>>,
    config: Arc<tokio::sync::Mutex<Config>>,
    ui: Arc<dyn UiController>,
}

impl ToolRegistry {
    pub fn new(config: Arc<tokio::sync::Mutex<Config>>, ui: Arc<dyn UiController>) -> Self {
        Self {
            tools: HashMap::new(),
            config,
            ui,
        }
    }

    pub fn register<T>(&mut self, tool: T) -> Result<()>
    where
        T: Tool + 'static,
    {
        let tool: Arc<dyn Tool> = Arc::new(tool);
        let name = tool.name();

        if let Some(reason) = tool_identifier_violation(name) {
            log::error!("Tool '{}' failed validation: {}", name, reason);
            return Err(Error::InvalidInput(format!(
                "Tool '{name}' is not a valid MCP identifier: {reason}"
            )));
        }

        if self
            .tools
            .insert(name.to_string(), Arc::clone(&tool))
            .is_some()
        {
            log::warn!(
                "Tool '{}' was already registered; overwriting previous entry.",
                name
            );
        }

        Ok(())
    }

    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
        self.tools.get(name).cloned()
    }

    pub fn all(&self) -> Vec<Arc<dyn Tool>> {
        self.tools.values().cloned().collect()
    }

    pub async fn execute(&self, name: &str, args: Value, mode: Mode) -> Result<ToolResult> {
        let canonical = canonical_tool_name(name);
        let tool = self
            .get(canonical)
            .with_context(|| format!("Tool not registered: {}", name))?;

        let mut config = self.config.lock().await;

        // Check mode-based tool availability first
        if !(config.modes.is_tool_allowed(mode, canonical)
            || config.modes.is_tool_allowed(mode, name))
        {
            let alternate_mode = match mode {
                Mode::Chat => Mode::Code,
                Mode::Code => Mode::Chat,
            };

            if config.modes.is_tool_allowed(alternate_mode, canonical)
                || config.modes.is_tool_allowed(alternate_mode, name)
            {
                return Ok(ToolResult::error(&format!(
                    "Tool '{}' is not available in {} mode. Switch to {} mode to use this tool (use :mode {} command).",
                    name, mode, alternate_mode, alternate_mode
                )));
            } else {
                return Ok(ToolResult::error(&format!(
                    "Tool '{}' is not available in any mode. Check your configuration.",
                    name
                )));
            }
        }

        let is_enabled = match canonical {
            WEB_SEARCH_TOOL_NAME => config.tools.web_search.enabled,
            "code_exec" => config.tools.code_exec.enabled,
            _ => true, // All other tools are considered enabled by default
        };

        if !is_enabled {
            let prompt = format!(
                "Tool '{}' is disabled. Would you like to enable it for this session?",
                name
            );
            if self.ui.confirm(&prompt).await {
                // Enable the tool in the in-memory config for the current session
                match canonical {
                    WEB_SEARCH_TOOL_NAME => config.tools.web_search.enabled = true,
                    "code_exec" => config.tools.code_exec.enabled = true,
                    _ => {}
                }
            } else {
                return Ok(ToolResult::cancelled(&format!(
                    "Tool '{}' execution was cancelled by the user.",
                    name
                )));
            }
        }

        tool.execute(args).await
    }

    /// Get all tools available in the given mode
    pub async fn available_tools(&self, mode: Mode) -> Vec<String> {
        let config = self.config.lock().await;
        self.tools
            .keys()
            .filter(|name| config.modes.is_tool_allowed(mode, name))
            .cloned()
            .collect()
    }

    pub fn tools(&self) -> Vec<String> {
        self.tools.keys().cloned().collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::tools::{Tool, ToolResult, WEB_SEARCH_TOOL_NAME};
    use crate::ui::NoOpUiController;
    use async_trait::async_trait;
    use serde_json::{Value, json};
    use std::sync::Arc;

    struct DummyTool {
        name: &'static str,
    }

    #[async_trait]
    impl Tool for DummyTool {
        fn name(&self) -> &'static str {
            self.name
        }

        fn description(&self) -> &'static str {
            "dummy tool"
        }

        fn schema(&self) -> Value {
            json!({ "type": "object" })
        }

        fn aliases(&self) -> &'static [&'static str] {
            &[]
        }

        async fn execute(&self, _args: Value) -> Result<ToolResult> {
            Ok(ToolResult::success(json!({ "echo": true })))
        }
    }

    fn registry() -> ToolRegistry {
        let config = Arc::new(tokio::sync::Mutex::new(Config::default()));
        let ui = Arc::new(NoOpUiController);
        ToolRegistry::new(config, ui)
    }

    #[test]
    fn rejects_invalid_tool_identifier() {
        let mut registry = registry();
        let tool = DummyTool {
            name: "invalid.tool",
        };

        let err = registry.register(tool).unwrap_err();
        assert!(matches!(err, Error::InvalidInput(_)));
    }

    #[test]
    fn registers_spec_compliant_tool() {
        let mut registry = registry();
        let tool = DummyTool {
            name: WEB_SEARCH_TOOL_NAME,
        };

        registry.register(tool).unwrap();
        assert!(registry.get(WEB_SEARCH_TOOL_NAME).is_some());
    }
}
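
A minimal sketch of wiring the registry above, reusing the test-only DummyTool for brevity; the tool name is illustrative:

use std::sync::Arc;
use serde_json::json;

async fn demo_registry() -> Result<ToolResult> {
    let config = Arc::new(tokio::sync::Mutex::new(Config::default()));
    let mut registry = ToolRegistry::new(config, Arc::new(NoOpUiController));
    registry.register(DummyTool { name: "dummy_echo" })?;

    // Mode gating, the enabled check, and the consent prompt all happen
    // inside execute(); callers just pass the name, args, and mode.
    registry.execute("dummy_echo", json!({}), Mode::Chat).await
}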
@@ -1,102 +0,0 @@
use super::{Tool, ToolResult};
use crate::Result;
use anyhow::Context;
use async_trait::async_trait;
use reqwest::Client;
use serde_json::{Value, json};

/// Tool that fetches the raw HTML content for a list of URLs.
///
/// Input schema expects:
///   urls: array of strings (max 5 URLs)
///   timeout_secs: optional integer per-request timeout (default 10)
pub struct WebScrapeTool {
    client: Client,
}

impl Default for WebScrapeTool {
    fn default() -> Self {
        Self::new()
    }
}

impl WebScrapeTool {
    pub fn new() -> Self {
        let client = Client::builder()
            .user_agent("OwlenWebScrape/0.1")
            .build()
            .expect("Failed to build reqwest client");
        Self { client }
    }
}

#[async_trait]
impl Tool for WebScrapeTool {
    fn name(&self) -> &'static str {
        "web_scrape"
    }

    fn description(&self) -> &'static str {
        "Fetch raw HTML content for a list of URLs"
    }

    fn schema(&self) -> Value {
        json!({
            "type": "object",
            "properties": {
                "urls": {
                    "type": "array",
                    "items": { "type": "string", "format": "uri" },
                    "minItems": 1,
                    "maxItems": 5,
                    "description": "List of URLs to scrape"
                },
                "timeout_secs": {
                    "type": "integer",
                    "minimum": 1,
                    "maximum": 30,
                    "default": 10,
                    "description": "Per-request timeout in seconds"
                }
            },
            "required": ["urls"],
            "additionalProperties": false
        })
    }

    fn requires_network(&self) -> bool {
        true
    }

    async fn execute(&self, args: Value) -> Result<ToolResult> {
        let urls = args
            .get("urls")
            .and_then(|v| v.as_array())
            .context("Missing 'urls' array")?;
        let timeout_secs = args
            .get("timeout_secs")
            .and_then(|v| v.as_u64())
            .unwrap_or(10);

        let mut results = Vec::new();
        for url_val in urls {
            let url = url_val.as_str().unwrap_or("");
            let resp = self
                .client
                .get(url)
                .timeout(std::time::Duration::from_secs(timeout_secs))
                .send()
                .await;
            match resp {
                Ok(r) => {
                    let text = r.text().await.unwrap_or_default();
                    results.push(json!({ "url": url, "content": text }));
                }
                Err(e) => {
                    results.push(json!({ "url": url, "error": e.to_string() }));
                }
            }
        }
        Ok(ToolResult::success(json!({ "pages": results })))
    }
}
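
A minimal sketch of a call into the scraper above; the URLs are placeholders:

use serde_json::json;

async fn demo_scrape() -> Result<ToolResult> {
    let tool = WebScrapeTool::new();
    // Failures are reported per URL inside the "pages" array instead of
    // aborting the whole batch.
    tool.execute(json!({
        "urls": ["https://example.com", "https://example.org"],
        "timeout_secs": 5
    }))
    .await
}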
@@ -1,165 +0,0 @@
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};

use crate::Result;
use anyhow::{Context, anyhow};
use async_trait::async_trait;
use reqwest::{Client, StatusCode, Url};
use serde_json::{Value, json};

use super::{Tool, ToolResult};
use crate::consent::ConsentManager;
use crate::tools::WEB_SEARCH_TOOL_NAME;

/// Configuration applied to the web search tool at registration time.
#[derive(Clone, Debug)]
pub struct WebSearchSettings {
    pub endpoint: Url,
    pub api_key: String,
    pub provider_label: String,
    pub timeout: Duration,
}

pub struct WebSearchTool {
    consent_manager: Arc<Mutex<ConsentManager>>,
    client: Client,
    settings: WebSearchSettings,
}

impl WebSearchTool {
    pub fn new(consent_manager: Arc<Mutex<ConsentManager>>, settings: WebSearchSettings) -> Self {
        let client = Client::builder()
            .timeout(settings.timeout)
            .build()
            .expect("failed to construct reqwest client for web search");

        Self {
            consent_manager,
            client,
            settings,
        }
    }
}

#[async_trait]
impl Tool for WebSearchTool {
    fn name(&self) -> &'static str {
        WEB_SEARCH_TOOL_NAME
    }

    fn description(&self) -> &'static str {
        "Search the web using the active cloud provider."
    }

    fn schema(&self) -> Value {
        json!({
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "minLength": 1,
                    "maxLength": 500,
                    "description": "Search query text"
                },
                "max_results": {
                    "type": "integer",
                    "minimum": 1,
                    "maximum": 10,
                    "default": 5,
                    "description": "Maximum number of search results to retrieve"
                }
            },
            "required": ["query"],
            "additionalProperties": false
        })
    }

    fn requires_network(&self) -> bool {
        true
    }

    async fn execute(&self, args: Value) -> Result<ToolResult> {
        let start = Instant::now();

        {
            let consent = self
                .consent_manager
                .lock()
                .expect("Consent manager mutex poisoned");

            if !consent.has_consent(self.name()) {
                return Ok(ToolResult::error(
                    "Consent not granted for web search. Enable the tool from the UI before invoking it.",
                ));
            }
        }

        let query = args
            .get("query")
            .and_then(Value::as_str)
            .map(str::trim)
            .filter(|q| !q.is_empty())
            .ok_or_else(|| anyhow!("Missing query parameter"))?;

        let max_results = args.get("max_results").and_then(Value::as_u64).unwrap_or(5) as u32;

        let payload = json!({
            "query": query,
            "max_results": max_results
        });

        let response = self
            .client
            .post(self.settings.endpoint.clone())
            .bearer_auth(&self.settings.api_key)
            .json(&payload)
            .send()
            .await
            .context("Web search request failed")?;

        match response.status() {
            StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
                return Ok(ToolResult::error(
                    "Cloud web search request was not authorized. Verify your Ollama Cloud API key.",
                ));
            }
            StatusCode::TOO_MANY_REQUESTS => {
                return Ok(ToolResult::error(
                    "Cloud web search is rate limited. Please wait before retrying.",
                ));
            }
            status if !status.is_success() => {
                return Ok(ToolResult::error(&format!(
                    "Cloud web search failed with status {}",
                    status
                )));
            }
            _ => {}
        }

        let body: Value = response
            .json()
            .await
            .context("Failed to decode cloud search response")?;

        let results = body
            .get("results")
            .and_then(|value| value.as_array())
            .cloned()
            .unwrap_or_else(Vec::new);

        let mut metadata = HashMap::new();
        metadata.insert("provider".to_string(), self.settings.provider_label.clone());

        let mut result = ToolResult::success(json!({
            "query": query,
            "provider": self.settings.provider_label,
            "results": results,
        }));
        result.duration = start.elapsed();
        result.metadata = metadata;

        Ok(result)
    }
}
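
A minimal sketch of registration-time setup for the tool above; the endpoint URL, environment variable, and label are illustrative placeholders, not values shipped by the project:

use std::sync::{Arc, Mutex};
use std::time::Duration;
use reqwest::Url;

fn demo_web_search(consent: Arc<Mutex<ConsentManager>>) -> WebSearchTool {
    let settings = WebSearchSettings {
        endpoint: Url::parse("https://search.example.invalid/v1/search").unwrap(),
        api_key: std::env::var("OWLEN_SEARCH_API_KEY").unwrap_or_default(),
        provider_label: "ollama_cloud".to_string(),
        timeout: Duration::from_secs(15),
    };
    WebSearchTool::new(consent, settings)
}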
@@ -1,364 +0,0 @@
//! Core types used across OWLEN

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::path::PathBuf;
use uuid::Uuid;

/// A message in a conversation
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Message {
    /// Unique identifier for this message
    pub id: Uuid,
    /// Role of the message sender (user, assistant, system)
    pub role: Role,
    /// Content of the message
    pub content: String,
    /// Optional metadata
    pub metadata: HashMap<String, serde_json::Value>,
    /// Timestamp when the message was created
    pub timestamp: std::time::SystemTime,
    /// Tool calls requested by the assistant
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Rich attachments (images, artifacts, files) associated with the message
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub attachments: Vec<MessageAttachment>,
}

/// Role of a message sender
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// Message from the user
    User,
    /// Message from the AI assistant
    Assistant,
    /// System message (prompts, context, etc.)
    System,
    /// Tool response message
    Tool,
}

/// A tool call requested by the assistant
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
    /// Unique identifier for this tool call
    pub id: String,
    /// Name of the tool to call
    pub name: String,
    /// Arguments for the tool (JSON object)
    pub arguments: serde_json::Value,
}

fn default_mime_type() -> String {
    "application/octet-stream".to_string()
}

/// Attachment associated with a message (image, artifact, or rich output).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct MessageAttachment {
    /// Unique identifier for this attachment instance.
    pub id: Uuid,
    /// Human friendly name, typically a filename.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// Optional descriptive text supplied by the sender.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// MIME type describing the payload.
    #[serde(default = "default_mime_type")]
    pub mime_type: String,
    /// Source filesystem path if the attachment originated from disk.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub source_path: Option<PathBuf>,
    /// Binary payload encoded as base64, when applicable.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub data_base64: Option<String>,
    /// Inline UTF-8 payload when the attachment is textual.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub text_content: Option<String>,
    /// Approximate size in bytes for UI hints.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub size_bytes: Option<u64>,
    /// Optional pre-rendered preview lines for fast UI rendering.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub preview_lines: Option<Vec<String>>,
}

impl MessageAttachment {
    /// Build an attachment from base64 encoded binary data.
    pub fn from_base64(
        name: impl Into<String>,
        mime_type: impl Into<String>,
        data_base64: String,
        size_bytes: Option<u64>,
    ) -> Self {
        Self {
            id: Uuid::new_v4(),
            name: Some(name.into()),
            description: None,
            mime_type: mime_type.into(),
            source_path: None,
            data_base64: Some(data_base64),
            text_content: None,
            size_bytes,
            preview_lines: None,
        }
    }

    /// Build an attachment from UTF-8 text content.
    pub fn from_text(name: Option<String>, mime_type: impl Into<String>, text: String) -> Self {
        Self {
            id: Uuid::new_v4(),
            name,
            description: None,
            mime_type: mime_type.into(),
            source_path: None,
            data_base64: None,
            text_content: Some(text),
            size_bytes: None,
            preview_lines: None,
        }
    }

    /// Attach a source path reference to the attachment.
    pub fn with_source_path(mut self, path: PathBuf) -> Self {
        self.source_path = Some(path);
        self
    }

    /// Set the description metadata for the attachment.
    pub fn with_description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }

    /// Provide pre-rendered preview lines for rapid UI display.
    pub fn with_preview_lines(mut self, lines: Vec<String>) -> Self {
        if lines.is_empty() {
            self.preview_lines = None;
        } else {
            self.preview_lines = Some(lines);
        }
        self
    }

    /// Returns true if the attachment MIME type indicates an image.
    pub fn is_image(&self) -> bool {
        self.mime_type.to_ascii_lowercase().starts_with("image/")
    }

    /// Accessor for base64 data payloads.
    pub fn base64_data(&self) -> Option<&str> {
        self.data_base64.as_deref()
    }

    /// Accessor for inline text payloads.
    pub fn text_data(&self) -> Option<&str> {
        self.text_content.as_deref()
    }
}

impl fmt::Display for Role {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::System => "system",
            Role::Tool => "tool",
        };
        f.write_str(label)
    }
}

/// A conversation containing multiple messages
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Conversation {
    /// Unique identifier for this conversation
    pub id: Uuid,
    /// Optional name/title for the conversation
    pub name: Option<String>,
    /// Optional AI-generated description of the conversation
    #[serde(default)]
    pub description: Option<String>,
    /// Messages in chronological order
    pub messages: Vec<Message>,
    /// Model used for this conversation
    pub model: String,
    /// When the conversation was created
    pub created_at: std::time::SystemTime,
    /// When the conversation was last updated
    pub updated_at: std::time::SystemTime,
}

/// Configuration for a chat completion request
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatRequest {
    /// The model to use for completion
    pub model: String,
    /// The conversation messages
    pub messages: Vec<Message>,
    /// Optional parameters for the request
    pub parameters: ChatParameters,
    /// Optional tools available for the model to use
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<crate::mcp::McpToolDescriptor>>,
}

/// Parameters for chat completion
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ChatParameters {
    /// Temperature for randomness (0.0 to 2.0)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Maximum tokens to generate
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Whether to stream the response
    #[serde(default)]
    pub stream: bool,
    /// Additional provider-specific parameters
    #[serde(flatten)]
    #[serde(default)]
    pub extra: HashMap<String, serde_json::Value>,
}

/// Response from a chat completion request
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatResponse {
    /// The generated message
    pub message: Message,
    /// Token usage information
    pub usage: Option<TokenUsage>,
    /// Whether this is a streaming chunk
    #[serde(default)]
    pub is_streaming: bool,
    /// Whether this is the final chunk in a stream
    #[serde(default)]
    pub is_final: bool,
}

/// Token usage information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens in the prompt
    pub prompt_tokens: u32,
    /// Tokens in the completion
    pub completion_tokens: u32,
    /// Total tokens used
    pub total_tokens: u32,
}

/// Information about an available model
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model identifier
    pub id: String,
    /// Human-readable name
    pub name: String,
    /// Model description
    pub description: Option<String>,
    /// Provider that hosts this model
    pub provider: String,
    /// Context window size
    pub context_window: Option<u32>,
    /// Additional capabilities
    pub capabilities: Vec<String>,
    /// Whether this model supports tool/function calling
    #[serde(default)]
    pub supports_tools: bool,
}

impl Message {
    /// Create a new message
    pub fn new(role: Role, content: String) -> Self {
        Self {
            id: Uuid::new_v4(),
            role,
            content,
            metadata: HashMap::new(),
            timestamp: std::time::SystemTime::now(),
            tool_calls: None,
            attachments: Vec::new(),
        }
    }

    /// Create a user message
    pub fn user(content: String) -> Self {
        Self::new(Role::User, content)
    }

    /// Create an assistant message
    pub fn assistant(content: String) -> Self {
        Self::new(Role::Assistant, content)
    }

    /// Create a system message
    pub fn system(content: String) -> Self {
        Self::new(Role::System, content)
    }

    /// Create a tool response message
    pub fn tool(tool_call_id: String, content: String) -> Self {
        let mut msg = Self::new(Role::Tool, content);
        msg.metadata.insert(
            "tool_call_id".to_string(),
            serde_json::Value::String(tool_call_id),
        );
        msg
    }

    /// Check if this message has tool calls
    pub fn has_tool_calls(&self) -> bool {
        self.tool_calls
            .as_ref()
            .map(|tc| !tc.is_empty())
            .unwrap_or(false)
    }

    /// Attach rich artifacts to the message.
    pub fn with_attachments(mut self, attachments: Vec<MessageAttachment>) -> Self {
        self.attachments = attachments;
        self
    }

    /// Return true when the message carries any attachments.
    pub fn has_attachments(&self) -> bool {
        !self.attachments.is_empty()
    }
}

impl Conversation {
    /// Create a new conversation
    pub fn new(model: String) -> Self {
        let now = std::time::SystemTime::now();
        Self {
            id: Uuid::new_v4(),
            name: None,
            description: None,
            messages: Vec::new(),
            model,
            created_at: now,
            updated_at: now,
        }
    }

    /// Add a message to the conversation
    pub fn add_message(&mut self, message: Message) {
        self.messages.push(message);
        self.updated_at = std::time::SystemTime::now();
    }

    /// Get the last message in the conversation
    pub fn last_message(&self) -> Option<&Message> {
        self.messages.last()
    }

    /// Clear all messages
    pub fn clear(&mut self) {
        self.messages.clear();
        self.updated_at = std::time::SystemTime::now();
    }
}
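
A minimal sketch exercising the builder-style types above; the model name and file contents are illustrative:

fn demo_conversation() -> Conversation {
    let mut convo = Conversation::new("llama3:8b".to_string());
    convo.add_message(Message::system("You are a helpful assistant.".to_string()));

    // Attachments ride along with an individual message, not the conversation.
    let note = MessageAttachment::from_text(
        Some("notes.txt".to_string()),
        "text/plain",
        "draft outline".to_string(),
    )
    .with_description("scratch notes");
    convo.add_message(Message::user("Summarize this.".to_string()).with_attachments(vec![note]));
    convo
}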
@@ -1,280 +0,0 @@
//! Shared UI components and state management for TUI applications
//!
//! This module contains reusable UI components that can be shared between
//! different TUI applications (chat, code, etc.)

/// Application state
pub use crate::state::AppState;

/// Input modes for TUI applications
pub use crate::state::InputMode;

/// Represents which panel is currently focused
pub use crate::state::FocusedPanel;

/// Auto-scroll state manager for scrollable panels
pub use crate::state::AutoScroll;

/// Visual selection state for text selection
pub use crate::state::VisualSelection;

use serde::{Deserialize, Serialize};

/// How role labels should be rendered alongside chat messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum RoleLabelDisplay {
    Inline,
    Above,
    None,
}

/// Extract text from a selection range in a list of lines
pub fn extract_text_from_selection(
    lines: &[String],
    start: (usize, usize),
    end: (usize, usize),
) -> Option<String> {
    if lines.is_empty() || start.0 >= lines.len() {
        return None;
    }

    let start_row = start.0;
    let start_col = start.1;
    let end_row = end.0.min(lines.len() - 1);
    let end_col = end.1;

    if start_row == end_row {
        // Single line selection
        let line = &lines[start_row];
        let chars: Vec<char> = line.chars().collect();
        let start_c = start_col.min(chars.len());
        let end_c = end_col.min(chars.len());

        if start_c >= end_c {
            return None;
        }

        let selected: String = chars[start_c..end_c].iter().collect();
        Some(selected)
    } else {
        // Multi-line selection
        let mut result = Vec::new();

        // First line: from start_col to end
        let first_line = &lines[start_row];
        let first_chars: Vec<char> = first_line.chars().collect();
        let start_c = start_col.min(first_chars.len());
        if start_c < first_chars.len() {
            result.push(first_chars[start_c..].iter().collect::<String>());
        }

        // Middle lines: entire lines
        for row in (start_row + 1)..end_row {
            if row < lines.len() {
                result.push(lines[row].clone());
            }
        }

        // Last line: from start to end_col
        if end_row < lines.len() && end_row > start_row {
            let last_line = &lines[end_row];
            let last_chars: Vec<char> = last_line.chars().collect();
            let end_c = end_col.min(last_chars.len());
            if end_c > 0 {
                result.push(last_chars[..end_c].iter().collect::<String>());
            }
        }

        if result.is_empty() {
            None
        } else {
            Some(result.join("\n"))
        }
    }
}

/// Cursor position for navigating scrollable content
pub use crate::state::CursorPosition;

/// Word boundary detection for navigation
pub fn find_next_word_boundary(line: &str, col: usize) -> Option<usize> {
    let chars: Vec<char> = line.chars().collect();

    if col >= chars.len() {
        return Some(chars.len());
    }

    let mut pos = col;
    let is_word_char = |c: char| c.is_alphanumeric() || c == '_';

    // Skip current word
    if is_word_char(chars[pos]) {
        while pos < chars.len() && is_word_char(chars[pos]) {
            pos += 1;
        }
    } else {
        // Skip non-word characters
        while pos < chars.len() && !is_word_char(chars[pos]) {
            pos += 1;
        }
    }

    Some(pos)
}

pub fn find_word_end(line: &str, col: usize) -> Option<usize> {
    let chars: Vec<char> = line.chars().collect();

    if col >= chars.len() {
        return Some(chars.len());
    }

    let mut pos = col;
    let is_word_char = |c: char| c.is_alphanumeric() || c == '_';

    // If on a word character, move to end of current word
    if is_word_char(chars[pos]) {
        while pos < chars.len() && is_word_char(chars[pos]) {
            pos += 1;
        }
        // Move back one to be ON the last character
        pos = pos.saturating_sub(1);
    } else {
        // Skip non-word characters
        while pos < chars.len() && !is_word_char(chars[pos]) {
            pos += 1;
        }
        // Now on first char of next word, move to its end
        while pos < chars.len() && is_word_char(chars[pos]) {
            pos += 1;
        }
        pos = pos.saturating_sub(1);
    }

    Some(pos)
}

pub fn find_prev_word_boundary(line: &str, col: usize) -> Option<usize> {
    let chars: Vec<char> = line.chars().collect();

    if col == 0 || chars.is_empty() {
        return Some(0);
    }

    let mut pos = col.min(chars.len());
    let is_word_char = |c: char| c.is_alphanumeric() || c == '_';

    // Move back one position first
    pos = pos.saturating_sub(1);

    // Skip non-word characters
    while pos > 0 && !is_word_char(chars[pos]) {
        pos -= 1;
    }

    // Skip word characters to find start of word
    while pos > 0 && is_word_char(chars[pos - 1]) {
        pos -= 1;
    }

    Some(pos)
}

use async_trait::async_trait;
use owlen_ui_common::Theme;

pub fn apply_theme_to_string(s: &str, _theme: &Theme) -> String {
    // This is a placeholder. In a real implementation, you'd parse the string
    // and apply colors based on syntax or other rules.
    s.to_string()
}

/// A trait for abstracting UI interactions like confirmations.
#[async_trait]
pub trait UiController: Send + Sync {
    async fn confirm(&self, prompt: &str) -> bool;
}

/// A no-op UI controller for non-interactive contexts.
pub struct NoOpUiController;

#[async_trait]
impl UiController for NoOpUiController {
    async fn confirm(&self, _prompt: &str) -> bool {
        false // Always decline in non-interactive mode
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_auto_scroll() {
        let mut scroll = AutoScroll {
            content_len: 100,
            ..Default::default()
        };

        // Test on_viewport with stick_to_bottom
        scroll.on_viewport(10);
        assert_eq!(scroll.scroll, 90);

        // Test user scroll up
        scroll.on_user_scroll(-10, 10);
        assert_eq!(scroll.scroll, 80);
        assert!(!scroll.stick_to_bottom);

        // Test jump to bottom
        scroll.jump_to_bottom(10);
        assert!(scroll.stick_to_bottom);
        assert_eq!(scroll.scroll, 90);
    }

    #[test]
    fn test_visual_selection() {
        let mut selection = VisualSelection::new();
        assert!(!selection.is_active());

        selection.start_at((0, 0));
        assert!(selection.is_active());

        selection.extend_to((2, 5));
        let normalized = selection.get_normalized();
        assert_eq!(normalized, Some(((0, 0), (2, 5))));

        selection.clear();
        assert!(!selection.is_active());
    }

    #[test]
    fn test_extract_text_single_line() {
        let lines = vec!["Hello World".to_string()];
        let result = extract_text_from_selection(&lines, (0, 0), (0, 5));
        assert_eq!(result, Some("Hello".to_string()));
    }

    #[test]
    fn test_extract_text_multi_line() {
        let lines = vec![
            "First line".to_string(),
            "Second line".to_string(),
            "Third line".to_string(),
        ];
        let result = extract_text_from_selection(&lines, (0, 6), (2, 5));
        assert_eq!(result, Some("line\nSecond line\nThird".to_string()));
    }

    #[test]
    fn test_word_boundaries() {
        let line = "hello world test";
        assert_eq!(find_next_word_boundary(line, 0), Some(5));
        assert_eq!(find_next_word_boundary(line, 5), Some(6));
        assert_eq!(find_next_word_boundary(line, 6), Some(11));

        assert_eq!(find_prev_word_boundary(line, 16), Some(12));
        assert_eq!(find_prev_word_boundary(line, 11), Some(6));
        assert_eq!(find_prev_word_boundary(line, 6), Some(0));
    }
}
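
A minimal sketch combining VisualSelection with the extraction helper above:

fn demo_copy_selection(lines: &[String], selection: &VisualSelection) -> Option<String> {
    // get_normalized() orders the endpoints so (start, end) always runs
    // top-left to bottom-right, matching what the extractor expects.
    let (start, end) = selection.get_normalized()?;
    extract_text_from_selection(lines, start, end)
}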
@@ -1,329 +0,0 @@
use crate::{Error, Result};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::path::{Path, PathBuf};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::fs;

const LEDGER_VERSION: u32 = 1;
const SECONDS_PER_HOUR: i64 = 60 * 60;
const SECONDS_PER_WEEK: i64 = 7 * 24 * 60 * 60;

#[derive(Clone, Debug, Serialize, Deserialize)]
struct UsageRecord {
    timestamp: i64,
    prompt_tokens: u32,
    completion_tokens: u32,
}

#[derive(Serialize, Deserialize)]
struct LedgerFile {
    version: u32,
    providers: HashMap<String, VecDeque<UsageRecord>>,
}

impl Default for LedgerFile {
    fn default() -> Self {
        Self {
            version: LEDGER_VERSION,
            providers: HashMap::new(),
        }
    }
}

#[derive(Clone, Debug, Default)]
pub struct UsageLedger {
    path: PathBuf,
    providers: HashMap<String, VecDeque<UsageRecord>>,
}

#[derive(Clone, Debug, Default)]
pub struct UsageQuota {
    pub hourly_quota_tokens: Option<u64>,
    pub weekly_quota_tokens: Option<u64>,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum UsageWindow {
    Hour,
    Week,
}

#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub enum UsageBand {
    Normal = 0,
    Warning = 1,
    Critical = 2,
}

#[derive(Clone, Debug, Default)]
pub struct WindowMetrics {
    pub prompt_tokens: u64,
    pub completion_tokens: u64,
    pub total_tokens: u64,
    pub quota_tokens: Option<u64>,
}

impl WindowMetrics {
    pub fn percent_of_quota(&self) -> Option<f64> {
        let quota = self.quota_tokens?;
        if quota == 0 {
            return None;
        }
        Some(self.total_tokens as f64 / quota as f64)
    }

    pub fn band(&self) -> UsageBand {
        match self.percent_of_quota() {
            Some(p) if p >= 0.95_f64 => UsageBand::Critical,
            Some(p) if p >= 0.80_f64 => UsageBand::Warning,
            _ => UsageBand::Normal,
        }
    }
}

#[derive(Clone, Debug, Default)]
pub struct UsageSnapshot {
    pub provider: String,
    pub hourly: WindowMetrics,
    pub weekly: WindowMetrics,
    pub last_updated: Option<SystemTime>,
}

impl UsageSnapshot {
    pub fn window(&self, window: UsageWindow) -> &WindowMetrics {
        match window {
            UsageWindow::Hour => &self.hourly,
            UsageWindow::Week => &self.weekly,
        }
    }
}

impl UsageLedger {
    pub fn empty(path: PathBuf) -> Self {
        Self {
            path,
            providers: HashMap::new(),
        }
    }

    pub async fn load_or_default(path: PathBuf) -> Result<Self> {
        if !path.exists() {
            return Ok(Self {
                path,
                providers: HashMap::new(),
            });
        }

        let contents = fs::read_to_string(&path)
            .await
            .map_err(|err| Error::Storage(format!("Failed to read usage ledger: {err}")))?;

        let file: LedgerFile = match serde_json::from_str(&contents) {
            Ok(file) => file,
            Err(err) => {
                return Err(Error::Storage(format!(
                    "Failed to parse usage ledger at {}: {err}",
                    path.display()
                )));
            }
        };

        Ok(Self {
            path,
            providers: file.providers,
        })
    }

    pub async fn persist(&self) -> Result<()> {
        if let Some(parent) = self.path.parent() {
            fs::create_dir_all(parent)
                .await
                .map_err(|err| Error::Storage(format!("Failed to create data directory: {err}")))?;
        }

        let serialized = self.serialize()?;

        fs::write(&self.path, serialized)
            .await
            .map_err(|err| Error::Storage(format!("Failed to write usage ledger: {err}")))?;

        Ok(())
    }

    pub fn record(
        &mut self,
        provider: &str,
        usage: &crate::types::TokenUsage,
        timestamp: SystemTime,
    ) {
        let total_tokens = usage.total_tokens;
        if total_tokens == 0 {
            return;
        }

        let ts = match timestamp.duration_since(UNIX_EPOCH) {
            Ok(duration) => duration.as_secs() as i64,
            Err(_) => 0,
        };

        let entry = self.providers.entry(provider.to_string()).or_default();

        entry.push_back(UsageRecord {
            timestamp: ts,
            prompt_tokens: usage.prompt_tokens,
            completion_tokens: usage.completion_tokens,
        });

        self.prune_old(provider, ts);
    }

    pub fn provider_keys(&self) -> impl Iterator<Item = &String> {
        self.providers.keys()
    }

    pub fn serialize(&self) -> Result<String> {
        let file = LedgerFile {
            version: LEDGER_VERSION,
            providers: self.providers.clone(),
        };

        serde_json::to_string_pretty(&file)
            .map_err(|err| Error::Storage(format!("Failed to serialize usage ledger: {err}")))
    }

    pub fn path(&self) -> &Path {
        &self.path
    }

    pub fn snapshot(&self, provider: &str, quotas: UsageQuota, now: SystemTime) -> UsageSnapshot {
        let now_secs = now
            .duration_since(UNIX_EPOCH)
            .unwrap_or_else(|_| Duration::from_secs(0))
            .as_secs() as i64;

        let mut snapshot = UsageSnapshot {
            provider: provider.to_string(),
            hourly: WindowMetrics {
                quota_tokens: quotas.hourly_quota_tokens,
                ..Default::default()
            },
            weekly: WindowMetrics {
                quota_tokens: quotas.weekly_quota_tokens,
                ..Default::default()
            },
            last_updated: None,
        };

        if let Some(records) = self.providers.get(provider) {
            for record in records {
                if now_secs - record.timestamp <= SECONDS_PER_HOUR {
                    snapshot.hourly.prompt_tokens += record.prompt_tokens as u64;
                    snapshot.hourly.completion_tokens += record.completion_tokens as u64;
                }

                if now_secs - record.timestamp <= SECONDS_PER_WEEK {
                    snapshot.weekly.prompt_tokens += record.prompt_tokens as u64;
                    snapshot.weekly.completion_tokens += record.completion_tokens as u64;
                }
            }

            snapshot.hourly.total_tokens =
                snapshot.hourly.prompt_tokens + snapshot.hourly.completion_tokens;
            snapshot.weekly.total_tokens =
                snapshot.weekly.prompt_tokens + snapshot.weekly.completion_tokens;

            snapshot.last_updated = records.back().and_then(|record| {
                UNIX_EPOCH.checked_add(Duration::from_secs(record.timestamp as u64))
            });
        }

        snapshot
    }

    pub fn prune_old(&mut self, provider: &str, now_secs: i64) {
        if let Some(records) = self.providers.get_mut(provider) {
            while let Some(front) = records.front() {
                if now_secs - front.timestamp > SECONDS_PER_WEEK {
                    records.pop_front();
                } else {
                    break;
                }
            }
        }
    }

    pub fn prune_all(&mut self, now: SystemTime) {
        let now_secs = now
            .duration_since(UNIX_EPOCH)
            .unwrap_or_else(|_| Duration::from_secs(0))
            .as_secs() as i64;
        let provider_keys: Vec<String> = self.providers.keys().cloned().collect();
        for provider in provider_keys {
            self.prune_old(&provider, now_secs);
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::TokenUsage;
    use std::time::{Duration, UNIX_EPOCH};
    use tempfile::tempdir;

    fn make_usage(prompt: u32, completion: u32) -> TokenUsage {
        TokenUsage {
            prompt_tokens: prompt,
            completion_tokens: completion,
            total_tokens: prompt.saturating_add(completion),
        }
    }

    #[test]
    fn records_and_summarizes_usage() {
        let temp = tempdir().expect("tempdir");
        let path = temp.path().join("ledger.json");
        let mut ledger = UsageLedger::empty(path);

        let usage = make_usage(40, 10);
        let timestamp = UNIX_EPOCH + Duration::from_secs(1);
        ledger.record("ollama_cloud", &usage, timestamp);

        let quotas = UsageQuota {
            hourly_quota_tokens: Some(100),
            weekly_quota_tokens: Some(1000),
        };

        let snapshot = ledger.snapshot("ollama_cloud", quotas, UNIX_EPOCH + Duration::from_secs(2));

        assert_eq!(snapshot.hourly.total_tokens, 50);
        assert_eq!(snapshot.weekly.total_tokens, 50);
        assert_eq!(snapshot.hourly.quota_tokens, Some(100));
        assert_eq!(snapshot.weekly.quota_tokens, Some(1000));
        assert_eq!(snapshot.hourly.band(), UsageBand::Normal);
    }

    #[test]
    fn prunes_records_outside_week() {
        let temp = tempdir().expect("tempdir");
        let path = temp.path().join("ledger.json");
        let mut ledger = UsageLedger::empty(path);

        let old_usage = make_usage(30, 5);
        let recent_usage = make_usage(20, 5);

        let base = UNIX_EPOCH;
        ledger.record("ollama_cloud", &old_usage, base);

        // Advance beyond a week for the second record.
        let later = UNIX_EPOCH + Duration::from_secs(SECONDS_PER_WEEK as u64 + 120);
        ledger.record("ollama_cloud", &recent_usage, later);

        let quotas = UsageQuota::default();
        let snapshot = ledger.snapshot("ollama_cloud", quotas, later);

        assert_eq!(snapshot.hourly.total_tokens, 25);
        assert_eq!(snapshot.weekly.total_tokens, 25);
    }
}
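
A minimal sketch of the record/snapshot/persist cycle above; the quota numbers are illustrative:

use std::time::SystemTime;

async fn demo_track(ledger: &mut UsageLedger, usage: &crate::types::TokenUsage) -> Result<()> {
    let now = SystemTime::now();
    ledger.record("ollama_cloud", usage, now);

    let quotas = UsageQuota {
        hourly_quota_tokens: Some(100_000),
        weekly_quota_tokens: Some(1_000_000),
    };
    let snapshot = ledger.snapshot("ollama_cloud", quotas, now);
    if snapshot.hourly.band() >= UsageBand::Warning {
        log::warn!("approaching hourly token quota");
    }
    ledger.persist().await
}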
@@ -1,109 +0,0 @@
use std::collections::HashMap;

use anyhow::{Context, Result};
use jsonschema::{JSONSchema, ValidationError};
use serde_json::{Value, json};

use crate::tools::WEB_SEARCH_TOOL_NAME;

pub struct SchemaValidator {
    schemas: HashMap<String, JSONSchema>,
}

impl Default for SchemaValidator {
    fn default() -> Self {
        Self::new()
    }
}

impl SchemaValidator {
    pub fn new() -> Self {
        Self {
            schemas: HashMap::new(),
        }
    }

    pub fn register_schema(&mut self, tool_name: &str, schema: Value) -> Result<()> {
        let compiled = JSONSchema::compile(&schema)
            .map_err(|e| anyhow::anyhow!("Invalid schema for {}: {}", tool_name, e))?;

        self.schemas.insert(tool_name.to_string(), compiled);
        Ok(())
    }

    pub fn validate(&self, tool_name: &str, input: &Value) -> Result<()> {
        let schema = self
            .schemas
            .get(tool_name)
            .with_context(|| format!("No schema registered for tool: {}", tool_name))?;

        if let Err(errors) = schema.validate(input) {
            let error_messages: Vec<String> = errors.map(format_validation_error).collect();

            return Err(anyhow::anyhow!(
                "Input validation failed for {}: {}",
                tool_name,
                error_messages.join(", ")
            ));
        }

        Ok(())
    }
}

fn format_validation_error(error: ValidationError) -> String {
    format!("Validation error at {}: {}", error.instance_path, error)
}

pub fn get_builtin_schemas() -> HashMap<String, Value> {
    let mut schemas = HashMap::new();

    let web_search_schema = json!({
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "minLength": 1,
                "maxLength": 500
            },
            "max_results": {
                "type": "integer",
                "minimum": 1,
                "maximum": 10,
                "default": 5
            }
        },
        "required": ["query"],
        "additionalProperties": false
    });

    schemas.insert(WEB_SEARCH_TOOL_NAME.to_string(), web_search_schema.clone());

    schemas.insert(
        "code_exec".to_string(),
        json!({
            "type": "object",
            "properties": {
                "language": {
                    "type": "string",
                    "enum": ["python", "javascript", "bash", "rust"]
                },
                "code": {
                    "type": "string",
                    "minLength": 1,
                    "maxLength": 10000
                },
                "timeout": {
                    "type": "integer",
                    "minimum": 1,
                    "maximum": 300,
                    "default": 30
                }
            },
            "required": ["language", "code"],
            "additionalProperties": false
        }),
    );

    schemas
}
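
A minimal sketch of loading the built-in schemas into the validator above and checking a call before dispatch; the query string is illustrative:

use serde_json::json;

fn demo_validate() -> Result<()> {
    let mut validator = SchemaValidator::new();
    for (name, schema) in get_builtin_schemas() {
        validator.register_schema(&name, schema)?;
    }
    // Rejects anything outside the schema, e.g. an empty query or
    // unknown extra properties.
    validator.validate(WEB_SEARCH_TOOL_NAME, &json!({ "query": "rust tui crates" }))
}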
@@ -1,90 +0,0 @@
#![allow(clippy::cast_possible_truncation)]

use unicode_segmentation::UnicodeSegmentation;
use unicode_width::UnicodeWidthStr;

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ScreenPos {
    pub row: u16,
    pub col: u16,
}

pub fn build_cursor_map(text: &str, width: u16) -> Vec<ScreenPos> {
    assert!(width > 0);
    let width = width as usize;
    let mut pos_map = vec![ScreenPos { row: 0, col: 0 }; text.len() + 1];
    let mut row = 0;
    let mut col = 0;

    let mut word_start_idx = 0;
    let mut word_start_col = 0;

    for (byte_offset, grapheme) in text.grapheme_indices(true) {
        let grapheme_width = UnicodeWidthStr::width(grapheme);

        if grapheme == "\n" {
            row += 1;
            col = 0;
            word_start_col = 0;
            word_start_idx = byte_offset + grapheme.len();
            // Set position for the end of this grapheme and any intermediate bytes
            let end_pos = ScreenPos {
                row: row as u16,
                col: col as u16,
            };
            for i in 1..=grapheme.len() {
                if byte_offset + i < pos_map.len() {
                    pos_map[byte_offset + i] = end_pos;
                }
            }
            continue;
        }

        if grapheme.chars().all(char::is_whitespace) {
            if col + grapheme_width > width {
                // Whitespace causes wrap
                row += 1;
                col = 1; // Position after wrapping space
                word_start_col = 1;
                word_start_idx = byte_offset + grapheme.len();
            } else {
                col += grapheme_width;
                word_start_col = col;
                word_start_idx = byte_offset + grapheme.len();
            }
        } else if col + grapheme_width > width {
            if word_start_col > 0 && byte_offset == word_start_idx {
                // This is the first character of a new word that won't fit, wrap it
                row += 1;
                col = grapheme_width;
            } else if word_start_col == 0 {
                // No previous word boundary, hard break
                row += 1;
                col = grapheme_width;
            } else {
                // This is part of a word already on the line, let it extend beyond width
                col += grapheme_width;
            }
        } else {
            col += grapheme_width;
        }

        // Set position for the end of this grapheme and any intermediate bytes
        let end_pos = ScreenPos {
            row: row as u16,
            col: col as u16,
        };
        for i in 1..=grapheme.len() {
            if byte_offset + i < pos_map.len() {
                pos_map[byte_offset + i] = end_pos;
            }
        }
    }

    pos_map
}

pub fn byte_to_screen_pos(text: &str, byte_idx: usize, width: u16) -> ScreenPos {
    let pos_map = build_cursor_map(text, width);
    pos_map[byte_idx.min(text.len())]
}
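
A minimal sketch of the typical call into the mapper above, e.g. to place the terminal cursor after an edit:

fn demo_cursor(text: &str, byte_idx: usize, width: u16) -> (u16, u16) {
    // byte_idx is a byte offset into `text`; the map handles wrapping
    // and multi-byte graphemes internally.
    let pos = byte_to_screen_pos(text, byte_idx, width);
    (pos.row, pos.col)
}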