Compare commits: dev ... 4a07b97eab (92 commits)

| SHA1 | Author | Date | |
|---|---|---|---|
| 4a07b97eab | |||
| 10c8e2baae | |||
| 09c8c9d83e | |||
| 5caf502009 | |||
| 04a7085007 | |||
| 6022aeb2b0 | |||
| e77e33ce2f | |||
| f87e5d2796 | |||
| 3c436fda54 | |||
| 173403379f | |||
| 688d1fe58a | |||
| b1b95a4560 | |||
| a024a764d6 | |||
| 686526bbd4 | |||
| 5134462deb | |||
| d7ddc365ec | |||
| 6108b9e3d1 | |||
| a6cf8585ef | |||
| baf833427a | |||
| d21945dbc0 | |||
| 7f39bf1eca | |||
| dcda8216dc | |||
| ff49e7ce93 | |||
| b63d26f0cd | |||
| 64fd3206a2 | |||
| 2a651ebd7b | |||
| 491fd049b0 | |||
| c9e2f9bae6 | |||
| 7b87459a72 | |||
| 4935a64a13 | |||
| a84c8a425d | |||
| 16c0e71147 | |||
| 0728262a9e | |||
| 7aa80fb0a4 | |||
| 28b6eb0a9a | |||
| 353c0a8239 | |||
| 44b07c8e27 | |||
| 76e59c2d0e | |||
| c92e07b866 | |||
| 9aa8722ec3 | |||
| 7daa4f4ebe | |||
| a788b8941e | |||
| 16bc534837 | |||
| eef0e3dea0 | |||
| 5d9ecec82c | |||
| 6980640324 | |||
| a0868a9b49 | |||
| 877ece07be | |||
| f6a3f235df | |||
| a4f7a45e56 | |||
| 94ef08db6b | |||
| 57942219a8 | |||
| 03244e8d24 | |||
| d7066d7d37 | |||
| 124db19e68 | |||
| e89da02d49 | |||
| cf0a8f21d5 | |||
| 2d45406982 | |||
| f592840d39 | |||
| 9090bddf68 | |||
| 4981a63224 | |||
| 1238bbe000 | |||
| f29f306692 | |||
| 9024e2b914 | |||
| 6849d5ef12 | |||
| 3c6e689de9 | |||
| 1994367a2e | |||
| c3a92a092b | |||
| 6a94373c4f | |||
| 83280f68cc | |||
| 21759898fb | |||
| 02df6d893c | |||
| 8f9d601fdc | |||
| 40e42c8918 | |||
| 6e12bb3acb | |||
| 16b6f24e3e | |||
| 25628d1d58 | |||
| e813736b47 | |||
| 7e2c6ea037 | |||
| 3f6d7d56f6 | |||
| bbb94367e1 | |||
| 79fdafce97 | |||
| 24671f5f2a | |||
| e0b14a42f2 | |||
| 3e8788dd44 | |||
| 38a4c55eaa | |||
| c7b7fe98ec | |||
| 4820a6706f | |||
| 3308b483f7 | |||
| 4ce4ac0b0e | |||
| 3722840d2c | |||
| 02f25b7bec |
@@ -1,20 +0,0 @@
|
||||
[target.x86_64-unknown-linux-musl]
|
||||
linker = "x86_64-linux-gnu-gcc"
|
||||
rustflags = ["-C", "target-feature=+crt-static", "-C", "link-arg=-lgcc"]
|
||||
|
||||
[target.aarch64-unknown-linux-gnu]
|
||||
linker = "aarch64-linux-gnu-gcc"
|
||||
|
||||
[target.aarch64-unknown-linux-musl]
|
||||
linker = "aarch64-linux-gnu-gcc"
|
||||
rustflags = ["-C", "target-feature=+crt-static", "-C", "link-arg=-lgcc"]
|
||||
|
||||
[target.armv7-unknown-linux-gnueabihf]
|
||||
linker = "arm-linux-gnueabihf-gcc"
|
||||
|
||||
[target.armv7-unknown-linux-musleabihf]
|
||||
linker = "arm-linux-gnueabihf-gcc"
|
||||
rustflags = ["-C", "target-feature=+crt-static", "-C", "link-arg=-lgcc"]
|
||||
|
||||
[target.x86_64-pc-windows-gnu]
|
||||
linker = "x86_64-w64-mingw32-gcc"
|
||||
.github/workflows/macos-check.yml (34 changed lines, vendored)
@@ -1,34 +0,0 @@
|
||||
name: macos-check
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- dev
|
||||
pull_request:
|
||||
branches:
|
||||
- dev
|
||||
|
||||
jobs:
|
||||
build:
|
||||
name: cargo check (macOS)
|
||||
runs-on: macos-latest
|
||||
steps:
|
||||
- name: Checkout sources
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Rust toolchain
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
|
||||
- name: Cache Cargo registry
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/registry
|
||||
~/.cargo/git
|
||||
target
|
||||
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
|
||||
restore-keys: |
|
||||
${{ runner.os }}-cargo-
|
||||
|
||||
- name: Cargo check
|
||||
run: cargo check --workspace --all-features
|
||||
.gitignore (39 changed lines, vendored)
@@ -1,13 +1,12 @@
### Custom
AGENTS.md
CLAUDE.md

### Rust template
# Generated by Cargo
# will have compiled files and executables
debug/
target/
dev/
.agents/
.env
.env.*
!.env.example

# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
@@ -19,17 +18,10 @@ Cargo.lock
# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb

# RustRover
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
### JetBrains template
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

.idea/
# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
@@ -60,14 +52,15 @@ Cargo.lock
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr
.idea/artifacts
.idea/compiler.xml
.idea/jarRepositories.xml
.idea/modules.xml
.idea/*.iml
.idea/modules
*.iml
*.ipr
.idea

# CMake
cmake-build-*/
@@ -104,3 +97,9 @@ fabric.properties

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

### rust-analyzer template
# Can be generated by other build systems other than cargo (ex: bazelbuild/rust_rules)
rust-project.json

@@ -1,35 +0,0 @@
|
||||
# Pre-commit hooks configuration
|
||||
# See https://pre-commit.com for more information
|
||||
|
||||
repos:
|
||||
# General file checks
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
- id: end-of-file-fixer
|
||||
- id: check-yaml
|
||||
args: ['--allow-multiple-documents']
|
||||
- id: check-toml
|
||||
- id: check-merge-conflict
|
||||
- id: check-added-large-files
|
||||
args: ['--maxkb=1000']
|
||||
- id: mixed-line-ending
|
||||
|
||||
# Rust formatting
|
||||
- repo: https://github.com/doublify/pre-commit-rust
|
||||
rev: v1.0
|
||||
hooks:
|
||||
- id: fmt
|
||||
name: cargo fmt
|
||||
description: Format Rust code with rustfmt
|
||||
- id: cargo-check
|
||||
name: cargo check
|
||||
description: Check Rust code compilation
|
||||
- id: clippy
|
||||
name: cargo clippy
|
||||
description: Lint Rust code with clippy
|
||||
args: ['--all-features', '--', '-D', 'warnings']
|
||||
|
||||
# Optional: run on all files when config changes
|
||||
default_install_hook_types: [pre-commit, pre-push]
|
||||
.woodpecker.yml (197 changed lines)
@@ -1,197 +0,0 @@
|
||||
---
|
||||
kind: pipeline
|
||||
name: pr-checks
|
||||
|
||||
when:
|
||||
event:
|
||||
- push
|
||||
- pull_request
|
||||
|
||||
steps:
|
||||
- name: fmt-clippy-test
|
||||
image: rust:1.83
|
||||
commands:
|
||||
- rustup component add rustfmt clippy
|
||||
- cargo fmt --all -- --check
|
||||
- cargo clippy --workspace --all-features -- -D warnings
|
||||
- cargo test --workspace --all-features
|
||||
|
||||
---
|
||||
kind: pipeline
|
||||
name: security-audit
|
||||
|
||||
when:
|
||||
event:
|
||||
- push
|
||||
- cron
|
||||
branch:
|
||||
- dev
|
||||
cron: weekly-security
|
||||
|
||||
steps:
|
||||
- name: cargo-audit
|
||||
image: rust:1.83
|
||||
commands:
|
||||
- cargo install cargo-audit --locked
|
||||
- cargo audit
|
||||
|
||||
---
|
||||
kind: pipeline
|
||||
name: release-tests
|
||||
|
||||
when:
|
||||
event: tag
|
||||
tag: v*
|
||||
|
||||
steps:
|
||||
- name: workspace-tests
|
||||
image: rust:1.83
|
||||
commands:
|
||||
- rustup component add llvm-tools-preview
|
||||
- cargo install cargo-llvm-cov --locked
|
||||
- cargo llvm-cov --workspace --all-features --summary-only
|
||||
- cargo llvm-cov --workspace --all-features --lcov --output-path coverage.lcov --no-run
|
||||
|
||||
---
|
||||
kind: pipeline
|
||||
name: release
|
||||
|
||||
when:
|
||||
event: tag
|
||||
tag: v*
|
||||
|
||||
variables:
|
||||
- &rust_image 'rust:1.83'
|
||||
|
||||
depends_on:
|
||||
- release-tests
|
||||
|
||||
matrix:
|
||||
include:
|
||||
# Linux
|
||||
- TARGET: x86_64-unknown-linux-gnu
|
||||
ARTIFACT: owlen-linux-x86_64-gnu
|
||||
PLATFORM: linux
|
||||
EXT: ""
|
||||
- TARGET: x86_64-unknown-linux-musl
|
||||
ARTIFACT: owlen-linux-x86_64-musl
|
||||
PLATFORM: linux
|
||||
EXT: ""
|
||||
- TARGET: aarch64-unknown-linux-gnu
|
||||
ARTIFACT: owlen-linux-aarch64-gnu
|
||||
PLATFORM: linux
|
||||
EXT: ""
|
||||
- TARGET: aarch64-unknown-linux-musl
|
||||
ARTIFACT: owlen-linux-aarch64-musl
|
||||
PLATFORM: linux
|
||||
EXT: ""
|
||||
- TARGET: armv7-unknown-linux-gnueabihf
|
||||
ARTIFACT: owlen-linux-armv7-gnu
|
||||
PLATFORM: linux
|
||||
EXT: ""
|
||||
- TARGET: armv7-unknown-linux-musleabihf
|
||||
ARTIFACT: owlen-linux-armv7-musl
|
||||
PLATFORM: linux
|
||||
EXT: ""
|
||||
# Windows
|
||||
- TARGET: x86_64-pc-windows-gnu
|
||||
ARTIFACT: owlen-windows-x86_64
|
||||
PLATFORM: windows
|
||||
EXT: ".exe"
|
||||
|
||||
steps:
|
||||
- name: build
|
||||
image: *rust_image
|
||||
commands:
|
||||
# Install cross-compilation tools
|
||||
- apt-get update
|
||||
- apt-get install -y musl-tools gcc-aarch64-linux-gnu g++-aarch64-linux-gnu gcc-arm-linux-gnueabihf g++-arm-linux-gnueabihf mingw-w64 zip
|
||||
|
||||
# Verify cross-compilers are installed
|
||||
- which aarch64-linux-gnu-gcc || echo "aarch64-linux-gnu-gcc not found!"
|
||||
- which arm-linux-gnueabihf-gcc || echo "arm-linux-gnueabihf-gcc not found!"
|
||||
- which x86_64-w64-mingw32-gcc || echo "x86_64-w64-mingw32-gcc not found!"
|
||||
|
||||
# Add rust target
|
||||
- rustup target add ${TARGET}
|
||||
|
||||
# Set up cross-compilation environment variables and build
|
||||
- |
|
||||
case "${TARGET}" in
|
||||
aarch64-unknown-linux-gnu)
|
||||
export CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER=/usr/bin/aarch64-linux-gnu-gcc
|
||||
export CC_aarch64_unknown_linux_gnu=/usr/bin/aarch64-linux-gnu-gcc
|
||||
export CXX_aarch64_unknown_linux_gnu=/usr/bin/aarch64-linux-gnu-g++
|
||||
export AR_aarch64_unknown_linux_gnu=/usr/bin/aarch64-linux-gnu-ar
|
||||
;;
|
||||
aarch64-unknown-linux-musl)
|
||||
export CARGO_TARGET_AARCH64_UNKNOWN_LINUX_MUSL_LINKER=/usr/bin/aarch64-linux-gnu-gcc
|
||||
export CC_aarch64_unknown_linux_musl=/usr/bin/aarch64-linux-gnu-gcc
|
||||
export CXX_aarch64_unknown_linux_musl=/usr/bin/aarch64-linux-gnu-g++
|
||||
export AR_aarch64_unknown_linux_musl=/usr/bin/aarch64-linux-gnu-ar
|
||||
;;
|
||||
armv7-unknown-linux-gnueabihf)
|
||||
export CARGO_TARGET_ARMV7_UNKNOWN_LINUX_GNUEABIHF_LINKER=/usr/bin/arm-linux-gnueabihf-gcc
|
||||
export CC_armv7_unknown_linux_gnueabihf=/usr/bin/arm-linux-gnueabihf-gcc
|
||||
export CXX_armv7_unknown_linux_gnueabihf=/usr/bin/arm-linux-gnueabihf-g++
|
||||
export AR_armv7_unknown_linux_gnueabihf=/usr/bin/arm-linux-gnueabihf-ar
|
||||
;;
|
||||
armv7-unknown-linux-musleabihf)
|
||||
export CARGO_TARGET_ARMV7_UNKNOWN_LINUX_MUSLEABIHF_LINKER=/usr/bin/arm-linux-gnueabihf-gcc
|
||||
export CC_armv7_unknown_linux_musleabihf=/usr/bin/arm-linux-gnueabihf-gcc
|
||||
export CXX_armv7_unknown_linux_musleabihf=/usr/bin/arm-linux-gnueabihf-g++
|
||||
export AR_armv7_unknown_linux_musleabihf=/usr/bin/arm-linux-gnueabihf-ar
|
||||
;;
|
||||
x86_64-pc-windows-gnu)
|
||||
export CARGO_TARGET_X86_64_PC_WINDOWS_GNU_LINKER=/usr/bin/x86_64-w64-mingw32-gcc
|
||||
export CC_x86_64_pc_windows_gnu=/usr/bin/x86_64-w64-mingw32-gcc
|
||||
export CXX_x86_64_pc_windows_gnu=/usr/bin/x86_64-w64-mingw32-g++
|
||||
export AR_x86_64_pc_windows_gnu=/usr/bin/x86_64-w64-mingw32-ar
|
||||
;;
|
||||
esac
|
||||
|
||||
# Build the project
|
||||
cargo build --release --all-features --target ${TARGET}
|
||||
|
||||
- name: package
|
||||
image: *rust_image
|
||||
commands:
|
||||
- apt-get update && apt-get install -y zip
|
||||
- mkdir -p dist
|
||||
- |
|
||||
if [ "${PLATFORM}" = "windows" ]; then
|
||||
cp target/${TARGET}/release/owlen.exe dist/owlen.exe
|
||||
cp target/${TARGET}/release/owlen-code.exe dist/owlen-code.exe
|
||||
cd dist
|
||||
zip -9 ${ARTIFACT}.zip owlen.exe owlen-code.exe
|
||||
cd ..
|
||||
mv dist/${ARTIFACT}.zip .
|
||||
sha256sum ${ARTIFACT}.zip > ${ARTIFACT}.zip.sha256
|
||||
else
|
||||
cp target/${TARGET}/release/owlen dist/owlen
|
||||
cp target/${TARGET}/release/owlen-code dist/owlen-code
|
||||
cd dist
|
||||
tar czf ${ARTIFACT}.tar.gz owlen owlen-code
|
||||
cd ..
|
||||
mv dist/${ARTIFACT}.tar.gz .
|
||||
sha256sum ${ARTIFACT}.tar.gz > ${ARTIFACT}.tar.gz.sha256
|
||||
fi
|
||||
|
||||
- name: release-notes
|
||||
image: *rust_image
|
||||
commands:
|
||||
- scripts/release-notes.sh "${CI_COMMIT_TAG}" release-notes.md
|
||||
|
||||
- name: release
|
||||
image: plugins/gitea-release
|
||||
settings:
|
||||
api_key:
|
||||
from_secret: gitea_token
|
||||
base_url: https://somegit.dev
|
||||
files:
|
||||
- ${ARTIFACT}.tar.gz
|
||||
- ${ARTIFACT}.tar.gz.sha256
|
||||
- ${ARTIFACT}.zip
|
||||
- ${ARTIFACT}.zip.sha256
|
||||
title: Release ${CI_COMMIT_TAG}
|
||||
note_file: release-notes.md
|
||||
AGENTS.md (798 changed lines)
@@ -1,798 +0,0 @@
|
||||
# AGENTS.md - AI Agent Instructions for Owlen Development
|
||||
|
||||
This document provides comprehensive context and guidelines for AI agents (Claude, GPT-4, etc.) working on the Owlen codebase.
|
||||
|
||||
## Project Overview
|
||||
|
||||
**Owlen** is a local-first, terminal-based AI assistant built in Rust using the Ratatui TUI framework. It implements a Model Context Protocol (MCP) architecture for modular tool execution and supports both local (Ollama) and cloud LLM providers.
|
||||
|
||||
**Core Philosophy:**
|
||||
- **Local-first**: Prioritize local LLMs (Ollama) with cloud as fallback
|
||||
- **Privacy-focused**: No telemetry, user data stays on device
|
||||
- **MCP-native**: All operations through MCP servers for modularity
|
||||
- **Terminal-native**: Vim-style modal interaction in a beautiful TUI
|
||||
|
||||
**Current Status:** v1.0 - MCP-only architecture (Phase 10 complete)
|
||||
|
||||
## Architecture
|
||||
|
||||
### Project Structure
|
||||
|
||||
```
|
||||
owlen/
|
||||
├── crates/
|
||||
│ ├── owlen-core/ # Core types, config, provider traits
|
||||
│ ├── owlen-tui/ # Ratatui-based terminal interface
|
||||
│ ├── owlen-cli/ # Command-line interface
|
||||
│ ├── owlen-ollama/ # Ollama provider implementation
|
||||
│ ├── owlen-mcp-llm-server/ # LLM inference as MCP server
|
||||
│ ├── owlen-mcp-client/ # MCP client library
|
||||
│ ├── owlen-mcp-server/ # Base MCP server framework
|
||||
│ ├── owlen-mcp-code-server/ # Code execution in Docker
|
||||
│ └── owlen-mcp-prompt-server/ # Prompt management server
|
||||
├── docs/ # Documentation
|
||||
├── themes/ # TUI color themes
|
||||
└── .agents/ # Agent development plans
|
||||
```
|
||||
|
||||
### Key Technologies
|
||||
|
||||
- **Language**: Rust 1.83+
|
||||
- **TUI**: Ratatui with Crossterm backend
|
||||
- **Async Runtime**: Tokio
|
||||
- **Config**: TOML (serde)
|
||||
- **HTTP Client**: reqwest
|
||||
- **LLM Providers**: Ollama (primary), with extensibility for OpenAI/Anthropic
|
||||
- **Protocol**: JSON-RPC 2.0 over STDIO/HTTP/WebSocket
|
||||
|
||||
## Current Features (v1.0)
|
||||
|
||||
### Core Capabilities
|
||||
|
||||
1. **MCP Architecture** (Phase 3-10 complete)
|
||||
- All LLM interactions via MCP servers
|
||||
- Local and remote MCP client support
|
||||
- STDIO, HTTP, WebSocket transports
|
||||
- Automatic failover with health checks
|
||||
|
||||
2. **Provider System**
|
||||
- Ollama (local and cloud)
|
||||
- Configurable per-provider settings
|
||||
- API key management with env variable expansion
|
||||
- Model switching via TUI (`:m` command)
|
||||
|
||||
3. **Agentic Loop** (ReAct pattern)
|
||||
- THOUGHT → ACTION → OBSERVATION cycle
|
||||
- Tool discovery and execution
|
||||
- Configurable iteration limits
|
||||
- Emergency stop (Ctrl+C)
|
||||
|
||||
4. **Mode System**
|
||||
- Chat mode: Limited tool availability
|
||||
- Code mode: Full tool access
|
||||
- Tool filtering by mode
|
||||
- Runtime mode switching
|
||||
|
||||
5. **Session Management**
|
||||
- Auto-save conversations
|
||||
- Session persistence with encryption
|
||||
- Description generation
|
||||
- Session timeout management
|
||||
|
||||
6. **Security**
|
||||
- Docker sandboxing for code execution
|
||||
- Tool whitelisting
|
||||
- Permission prompts for dangerous operations
|
||||
- Network isolation options
|
||||
|
||||
### TUI Features
|
||||
|
||||
- Vim-style modal editing (Normal, Insert, Visual, Command modes)
|
||||
- Multi-panel layout (conversation, status, input)
|
||||
- Syntax highlighting for code blocks
|
||||
- Theme system (10+ built-in themes)
|
||||
- Scrollback history (configurable limit)
|
||||
- Word wrap and visual selection
|
||||
|
||||
## Development Guidelines
|
||||
|
||||
### Code Style
|
||||
|
||||
1. **Rust Best Practices**
|
||||
- Use `rustfmt` (pre-commit hook enforced)
|
||||
- Run `cargo clippy` before commits
|
||||
- Prefer `Result` over `panic!` for errors
|
||||
- Document public APIs with `///` comments
|
||||
|
||||
2. **Error Handling**
|
||||
- Use `owlen_core::Error` enum for all errors
|
||||
- Chain errors with context (`.map_err(|e| Error::X(format!(...)))`)
|
||||
- Never unwrap in library code (tests OK)
|
||||
|
||||
3. **Async Patterns**
|
||||
- All I/O operations must be async
|
||||
- Use `tokio::spawn` for background tasks
|
||||
- Prefer `tokio::sync::mpsc` for channels
|
||||
- Always set timeouts for network operations
|
||||
|
||||
4. **Testing**
|
||||
- Unit tests in same file (`#[cfg(test)] mod tests`)
|
||||
- Use mock implementations from `test_utils` modules
|
||||
- Integration tests in `crates/*/tests/`
|
||||
- All public APIs must have tests
|
||||
|
||||
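A minimal sketch of how the error-handling and timeout guidelines above combine in practice. The error enum, endpoint, and function are illustrative stand-ins rather than the actual `owlen_core` API; the 30-second budget is an arbitrary placeholder.

```rust
use std::time::Duration;
use tokio::time::timeout;

// Hypothetical error type standing in for `owlen_core::Error`.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    #[error("provider error: {0}")]
    Provider(String),
    #[error("request timed out")]
    Timeout,
}
pub type Result<T> = std::result::Result<T, Error>;

/// Fetch a completion with an explicit network timeout, mapping the raw
/// transport error into the crate-level error enum instead of unwrapping.
pub async fn chat_with_timeout(base_url: &str, body: &str) -> Result<String> {
    let client = reqwest::Client::new();
    let fut = async {
        let resp = client
            .post(format!("{base_url}/api/chat"))
            .body(body.to_owned())
            .send()
            .await
            .map_err(|e| Error::Provider(format!("request failed: {e}")))?;
        resp.text()
            .await
            .map_err(|e| Error::Provider(format!("invalid response body: {e}")))
    };
    // Always bound network operations; the duration is illustrative.
    timeout(Duration::from_secs(30), fut)
        .await
        .map_err(|_| Error::Timeout)?
}
```
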
### File Organization

**When editing existing files:**
1. Read the entire file first (use `Read` tool)
2. Preserve existing code style and formatting
3. Update related tests in the same commit
4. Keep changes atomic and focused

**When creating new files:**
1. Check `crates/owlen-core/src/` for similar modules
2. Follow existing module structure
3. Add to `lib.rs` with appropriate visibility
4. Document module purpose with `//!` header

### Configuration

**Config file**: `~/.config/owlen/config.toml`

Example structure:
```toml
[general]
default_provider = "ollama"
default_model = "llama3.2:latest"
enable_streaming = true

[mcp]
# MCP is always enabled in v1.0+

[providers.ollama]
provider_type = "ollama"
base_url = "http://localhost:11434"

[providers.ollama-cloud]
provider_type = "ollama-cloud"
base_url = "https://ollama.com"
api_key = "$OLLAMA_API_KEY"

[ui]
theme = "default_dark"
word_wrap = true

[security]
enable_sandboxing = true
allowed_tools = ["web_search", "code_exec"]
```

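The `api_key = "$OLLAMA_API_KEY"` entry above relies on environment-variable expansion at load time. A rough sketch of that expansion step, assuming the `shellexpand` crate already listed in the workspace dependencies; the struct and field names are illustrative, not the real config schema.

```rust
use serde::Deserialize;

#[allow(dead_code)]
#[derive(Debug, Deserialize)]
struct ProviderConfig {
    provider_type: String,
    base_url: String,
    #[serde(default)]
    api_key: Option<String>,
}

/// Expand `$VAR` / `${VAR}` references in the configured API key.
/// Unset variables are treated as "no key configured" rather than an error.
fn resolve_api_key(cfg: &ProviderConfig) -> Option<String> {
    let raw = cfg.api_key.as_deref()?;
    match shellexpand::env(raw) {
        Ok(expanded) if !expanded.is_empty() => Some(expanded.into_owned()),
        _ => None,
    }
}

fn main() {
    let cfg: ProviderConfig = toml::from_str(
        r#"
        provider_type = "ollama-cloud"
        base_url = "https://ollama.com"
        api_key = "$OLLAMA_API_KEY"
        "#,
    )
    .expect("valid TOML");
    println!("{} -> key resolved: {}", cfg.base_url, resolve_api_key(&cfg).is_some());
}
```
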
### Common Tasks

#### Adding a New Provider

1. Create `crates/owlen-{provider}/` crate
2. Implement `owlen_core::provider::Provider` trait
3. Add to `owlen_core::router::ProviderRouter`
4. Update config schema in `owlen_core::config`
5. Add tests with `MockProvider` pattern
6. Document in `docs/provider-implementation.md`

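A skeletal version of step 2. The trait below is a simplified stand-in for `owlen_core::provider::Provider` (the real signature is not reproduced here), and `AcmeProvider` is a hypothetical backend; real code would likely need explicit `Send` bounds or `async_trait`.

```rust
// Simplified stand-ins for the real owlen-core types.
pub struct ChatRequest { pub prompt: String }
pub struct ChatResponse { pub content: String }
pub type Result<T> = anyhow::Result<T>;

pub trait Provider {
    fn name(&self) -> &str;
    async fn chat(&self, req: ChatRequest) -> Result<ChatResponse>;
}

/// Hypothetical new backend crate: `owlen-acme`.
pub struct AcmeProvider { pub base_url: String }

impl Provider for AcmeProvider {
    fn name(&self) -> &str { "acme" }

    async fn chat(&self, req: ChatRequest) -> Result<ChatResponse> {
        // Real code would POST to the provider's API here.
        let _ = (&self.base_url, &req.prompt);
        Ok(ChatResponse { content: "stub reply".into() })
    }
}
```
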
#### Adding a New MCP Server

1. Create `crates/owlen-mcp-{name}-server/` crate
2. Implement JSON-RPC 2.0 protocol handlers
3. Define tool descriptors with JSON schemas
4. Add sandboxing/security checks
5. Register in `mcp_servers` config array
6. Document tool capabilities

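For step 3, a tool descriptor boils down to a name, a description, and a JSON Schema for the arguments. A small sketch using serde_json; the field names follow common MCP `tools/list` conventions but are not copied from the owlen-mcp-server crate.

```rust
use serde::Serialize;
use serde_json::{json, Value};

#[derive(Serialize)]
struct ToolDescriptor {
    name: String,
    description: String,
    /// JSON Schema describing the expected arguments.
    input_schema: Value,
}

fn read_file_tool() -> ToolDescriptor {
    ToolDescriptor {
        name: "read_file".into(),
        description: "Read a UTF-8 text file inside the sandbox".into(),
        input_schema: json!({
            "type": "object",
            "properties": {
                "path": { "type": "string", "description": "Path relative to the workspace root" }
            },
            "required": ["path"]
        }),
    }
}

fn main() {
    // Roughly what a `tools/list` JSON-RPC response would embed.
    println!("{}", serde_json::to_string_pretty(&read_file_tool()).unwrap());
}
```
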
#### Adding a TUI Feature

1. Modify `crates/owlen-tui/src/chat_app.rs`
2. Update keybinding handlers
3. Extend UI rendering in `draw()` method
4. Add to help screen (`?` command)
5. Test with different terminal sizes
6. Ensure theme compatibility

## Feature Parity Roadmap

Based on analysis of OpenAI Codex and Claude Code, here are prioritized features to implement:

### Phase 11: MCP Client Enhancement (HIGHEST PRIORITY)

**Goal**: Full MCP client capabilities to access ecosystem tools

**Features:**
1. **MCP Server Management**
   - `owlen mcp add/list/remove` commands
   - Three config scopes: local, project (`.mcp.json`), user
   - Environment variable expansion in config
   - OAuth 2.0 authentication for remote servers

2. **MCP Resource References**
   - `@github:issue://123` syntax
   - `@postgres:schema://users` syntax
   - Auto-completion for resources

3. **MCP Prompts as Slash Commands**
   - `/mcp__github__list_prs`
   - Dynamic command registration

**Implementation:**
- Extend `owlen-mcp-client` crate
- Add `.mcp.json` parsing to `owlen-core::config`
- Update TUI command parser for `@` and `/mcp__` syntax
- Add OAuth flow to TUI

**Files to modify:**
- `crates/owlen-mcp-client/src/lib.rs`
- `crates/owlen-core/src/config.rs`
- `crates/owlen-tui/src/command_parser.rs`

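One way the `@server:type://id` references from feature 2 could be tokenized before reaching the MCP client. This is purely illustrative, not the planned parser.

```rust
/// Parsed form of e.g. `@github:issue://123` or `@postgres:schema://users`.
#[derive(Debug, PartialEq)]
struct ResourceRef<'a> {
    server: &'a str,
    kind: &'a str,
    id: &'a str,
}

fn parse_resource_ref(input: &str) -> Option<ResourceRef<'_>> {
    let rest = input.strip_prefix('@')?;
    let (server, rest) = rest.split_once(':')?;
    let (kind, id) = rest.split_once("://")?;
    if server.is_empty() || kind.is_empty() || id.is_empty() {
        return None;
    }
    Some(ResourceRef { server, kind, id })
}

fn main() {
    let r = parse_resource_ref("@github:issue://123").unwrap();
    assert_eq!(r, ResourceRef { server: "github", kind: "issue", id: "123" });
    assert!(parse_resource_ref("plain text").is_none());
    println!("parsed: {r:?}");
}
```
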
### Phase 12: Approval & Sandbox System (HIGHEST PRIORITY)

**Goal**: Safe agentic behavior with user control

**Features:**
1. **Three-tier Approval Modes**
   - `suggest`: Require user approval for every file write and shell command (default)
   - `auto-edit`: Auto-approve file changes, prompt for shell
   - `full-auto`: Auto-approve everything (requires Git repo)

2. **Platform-specific Sandboxing**
   - Linux: Docker with network isolation
   - macOS: Apple Seatbelt (`sandbox-exec`)
   - Windows: AppContainer or Job Objects

3. **Permission Management**
   - `/permissions` command in TUI
   - Tool allowlist (e.g., `Edit`, `Bash(git commit:*)`)
   - Stored in `.owlen/settings.json` (project) or `~/.owlen.json` (user)

**Implementation:**
- New `owlen-core::approval` module
- Extend `owlen-core::sandbox` with platform detection
- Update `owlen-mcp-code-server` to use new sandbox
- Add permission storage to config system

**Files to create:**
- `crates/owlen-core/src/approval.rs`
- `crates/owlen-core/src/sandbox/linux.rs`
- `crates/owlen-core/src/sandbox/macos.rs`
- `crates/owlen-core/src/sandbox/windows.rs`

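A possible starting point for `crates/owlen-core/src/approval.rs`. The variant names mirror the three tiers above; everything else (the `ActionKind` split, the serde spelling) is an assumption.

```rust
use serde::{Deserialize, Serialize};

/// The three approval tiers described above.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum ApprovalMode {
    /// Prompt before every file write and shell command.
    #[default]
    Suggest,
    /// Auto-approve file edits, still prompt for shell commands.
    AutoEdit,
    /// Approve everything automatically (only sensible inside a Git repo).
    FullAuto,
}

#[derive(Debug, Clone, Copy)]
pub enum ActionKind {
    FileWrite,
    ShellCommand,
}

impl ApprovalMode {
    /// Does this action need an interactive confirmation?
    pub fn requires_prompt(self, action: ActionKind) -> bool {
        match (self, action) {
            (ApprovalMode::Suggest, _) => true,
            (ApprovalMode::AutoEdit, ActionKind::ShellCommand) => true,
            (ApprovalMode::AutoEdit, ActionKind::FileWrite) => false,
            (ApprovalMode::FullAuto, _) => false,
        }
    }
}
```
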
### Phase 13: Project Documentation System (HIGH PRIORITY)

**Goal**: Massive usability improvement with project context

**Features:**
1. **OWLEN.md System**
   - `OWLEN.md` at repo root (checked into git)
   - `OWLEN.local.md` (gitignored, personal)
   - `~/.config/owlen/OWLEN.md` (global)
   - Support nested OWLEN.md in monorepos

2. **Auto-generation**
   - `/init` command to generate project-specific OWLEN.md
   - Analyze codebase structure
   - Detect build system, test framework
   - Suggest common commands

3. **Live Updates**
   - `#` command to add instructions to OWLEN.md
   - Context-aware insertion (relevant section)

**Contents of OWLEN.md:**
- Common bash commands
- Code style guidelines
- Testing instructions
- Core files and utilities
- Known quirks/warnings

**Implementation:**
- New `owlen-core::project_doc` module
- File discovery algorithm (walk up directory tree)
- Markdown parser for sections
- TUI commands: `/init`, `#`

**Files to create:**
- `crates/owlen-core/src/project_doc.rs`
- `crates/owlen-tui/src/commands/init.rs`

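The "walk up the directory tree" step could look roughly like this. The file names come from the list above; the function itself is only a sketch of the discovery order.

```rust
use std::path::{Path, PathBuf};

/// Collect every OWLEN.md / OWLEN.local.md from the starting directory up to
/// the filesystem root, closest files first (nested monorepo docs win).
fn discover_project_docs(start: &Path) -> Vec<PathBuf> {
    let mut found = Vec::new();
    for dir in start.ancestors() {
        for name in ["OWLEN.local.md", "OWLEN.md"] {
            let candidate = dir.join(name);
            if candidate.is_file() {
                found.push(candidate);
            }
        }
    }
    found
}

fn main() {
    let cwd = std::env::current_dir().expect("cwd");
    for doc in discover_project_docs(&cwd) {
        println!("would load: {}", doc.display());
    }
    // The global ~/.config/owlen/OWLEN.md would be appended separately.
}
```
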
### Phase 14: Non-Interactive Mode (HIGH PRIORITY)

**Goal**: Enable CI/CD integration and automation

**Features:**
1. **Headless Execution**
   ```bash
   owlen exec "fix linting errors" --approval-mode auto-edit
   owlen --quiet "update CHANGELOG" --json
   ```

2. **Environment Variables**
   - `OWLEN_QUIET_MODE=1`
   - `OWLEN_DISABLE_PROJECT_DOC=1`
   - `OWLEN_APPROVAL_MODE=full-auto`

3. **JSON Output**
   - Structured output for parsing
   - Exit codes for success/failure
   - Progress events on stderr

**Implementation:**
- New `owlen-cli` subcommand: `exec`
- Extend `owlen-core::session` with non-interactive mode
- Add JSON serialization for results
- Environment variable parsing in config

**Files to modify:**
- `crates/owlen-cli/src/main.rs`
- `crates/owlen-core/src/session.rs`

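For the JSON-output point, a non-interactive run could serialize a report like the one below and signal success through the exit code. The field set is hypothetical, not a committed schema.

```rust
use serde::Serialize;
use std::process::ExitCode;

#[derive(Serialize)]
struct ExecReport {
    prompt: String,
    success: bool,
    iterations: u32,
    /// Tool invocations performed during the run.
    tools_used: Vec<String>,
    final_message: String,
}

fn main() -> ExitCode {
    let report = ExecReport {
        prompt: "fix linting errors".into(),
        success: true,
        iterations: 3,
        tools_used: vec!["code_exec".into()],
        final_message: "clippy warnings resolved".into(),
    };
    // Structured result on stdout; progress events would go to stderr.
    println!("{}", serde_json::to_string_pretty(&report).unwrap());
    if report.success { ExitCode::SUCCESS } else { ExitCode::FAILURE }
}
```
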
### Phase 15: Multi-Provider Expansion (HIGH PRIORITY)

**Goal**: Support cloud providers while maintaining local-first

**Providers to add:**
1. OpenAI (GPT-4, o1, o4-mini)
2. Anthropic (Claude 3.5 Sonnet, Opus)
3. Google (Gemini Ultra, Pro)
4. Mistral AI

**Configuration:**
```toml
[providers.openai]
api_key = "${OPENAI_API_KEY}"
model = "o4-mini"
enabled = true

[providers.anthropic]
api_key = "${ANTHROPIC_API_KEY}"
model = "claude-3-5-sonnet"
enabled = true
```

**Runtime Switching:**
```
:model ollama/starcoder
:model openai/o4-mini
:model anthropic/claude-3-5-sonnet
```

**Implementation:**
- Create `owlen-openai`, `owlen-anthropic`, `owlen-google` crates
- Implement `Provider` trait for each
- Add runtime model switching to TUI
- Maintain Ollama as default

**Files to create:**
- `crates/owlen-openai/src/lib.rs`
- `crates/owlen-anthropic/src/lib.rs`
- `crates/owlen-google/src/lib.rs`

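Runtime switching as shown above implies splitting the `:model` argument into a provider prefix and a model name at the first slash. A small sketch, assuming that split is the only parsing rule:

```rust
/// Target of a `:model provider/model` command.
#[derive(Debug, PartialEq)]
struct ModelSelector<'a> {
    provider: &'a str,
    model: &'a str,
}

fn parse_model_arg(arg: &str) -> Option<ModelSelector<'_>> {
    let (provider, model) = arg.split_once('/')?;
    if provider.is_empty() || model.is_empty() {
        return None;
    }
    Some(ModelSelector { provider, model })
}

fn main() {
    assert_eq!(
        parse_model_arg("openai/o4-mini"),
        Some(ModelSelector { provider: "openai", model: "o4-mini" })
    );
    // A bare name could fall back to the default provider instead of failing.
    assert_eq!(parse_model_arg("starcoder"), None);
    println!("ok");
}
```
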
### Phase 16: Custom Slash Commands (MEDIUM PRIORITY)

**Goal**: User and team-defined workflows

**Features:**
1. **Command Directories**
   - `~/.owlen/commands/` (user, available everywhere)
   - `.owlen/commands/` (project, checked into git)
   - Support `$ARGUMENTS` keyword

2. **Example Structure**
   ```markdown
   # .owlen/commands/fix-github-issue.md
   Please analyze and fix GitHub issue: $ARGUMENTS.
   1. Use `gh issue view` to get details
   2. Implement changes
   3. Write and run tests
   4. Create PR
   ```

3. **TUI Integration**
   - Auto-complete for custom commands
   - Help text from command files
   - Parameter validation

**Implementation:**
- New `owlen-core::commands` module
- Command discovery and parsing
- Template expansion
- TUI command registration

**Files to create:**
- `crates/owlen-core/src/commands.rs`
- `crates/owlen-tui/src/commands/custom.rs`

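Template expansion can stay very small if `$ARGUMENTS` is the only keyword, as assumed in this sketch; the invocation shown in `main` is hypothetical.

```rust
use std::fs;
use std::path::Path;

/// Load a command file such as `.owlen/commands/fix-github-issue.md` and
/// substitute the user-supplied arguments for the `$ARGUMENTS` keyword.
fn expand_command(path: &Path, arguments: &str) -> std::io::Result<String> {
    let template = fs::read_to_string(path)?;
    Ok(template.replace("$ARGUMENTS", arguments))
}

fn main() -> std::io::Result<()> {
    // Hypothetical invocation: `/fix-github-issue 123`
    let prompt = expand_command(Path::new(".owlen/commands/fix-github-issue.md"), "123")?;
    println!("{prompt}");
    Ok(())
}
```
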
### Phase 17: Plugin System (MEDIUM PRIORITY)

**Goal**: One-command installation of tool collections

**Features:**
1. **Plugin Structure**
   ```json
   {
     "name": "github-workflow",
     "version": "1.0.0",
     "commands": [
       {"name": "pr", "file": "commands/pr.md"}
     ],
     "mcp_servers": [
       {
         "name": "github",
         "command": "${OWLEN_PLUGIN_ROOT}/bin/github-mcp"
       }
     ]
   }
   ```

2. **Installation**
   ```bash
   owlen plugin install github-workflow
   owlen plugin list
   owlen plugin remove github-workflow
   ```

3. **Discovery**
   - `~/.owlen/plugins/` directory
   - Git repository URLs
   - Plugin registry (future)

**Implementation:**
- New `owlen-core::plugins` module
- Plugin manifest parser
- Installation/removal logic
- Sandboxing for plugin code

**Files to create:**
- `crates/owlen-core/src/plugins.rs`
- `crates/owlen-cli/src/commands/plugin.rs`

### Phase 18: Extended Thinking Modes (MEDIUM PRIORITY)

**Goal**: Progressive computation budgets for complex tasks

**Modes:**
- `think` - basic extended thinking
- `think hard` - increased computation
- `think harder` - more computation
- `ultrathink` - maximum budget

**Implementation:**
- Extend `owlen-core::types::ChatParameters`
- Add thinking mode to TUI commands
- Configure per-provider max tokens

**Files to modify:**
- `crates/owlen-core/src/types.rs`
- `crates/owlen-tui/src/command_parser.rs`

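One way the four budgets could surface in `ChatParameters`; the enum is a sketch and the token numbers are placeholders, not tuned values.

```rust
/// Progressive thinking budgets, roughly as listed above.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ThinkingMode {
    #[default]
    Think,
    ThinkHard,
    ThinkHarder,
    Ultrathink,
}

impl ThinkingMode {
    /// Illustrative per-mode budgets; real values would be configured per provider.
    pub fn max_thinking_tokens(self) -> u32 {
        match self {
            ThinkingMode::Think => 4_000,
            ThinkingMode::ThinkHard => 10_000,
            ThinkingMode::ThinkHarder => 32_000,
            ThinkingMode::Ultrathink => 128_000,
        }
    }

    pub fn parse(input: &str) -> Option<Self> {
        match input.trim() {
            "think" => Some(Self::Think),
            "think hard" => Some(Self::ThinkHard),
            "think harder" => Some(Self::ThinkHarder),
            "ultrathink" => Some(Self::Ultrathink),
            _ => None,
        }
    }
}
```
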
### Phase 19: Git Workflow Automation (MEDIUM PRIORITY)

**Goal**: Streamline common Git operations

**Features:**
1. Auto-commit message generation
2. PR creation via `gh` CLI
3. Rebase conflict resolution
4. File revert operations
5. Git history analysis

**Implementation:**
- New `owlen-mcp-git-server` crate
- Tools: `commit`, `create_pr`, `rebase`, `revert`, `history`
- Integration with TUI commands

**Files to create:**
- `crates/owlen-mcp-git-server/src/lib.rs`

### Phase 20: Enterprise Features (LOW PRIORITY)

**Goal**: Team and enterprise deployment support

**Features:**
1. **Managed Configuration**
   - `/etc/owlen/managed-mcp.json` (Linux)
   - Restrict user additions with `useEnterpriseMcpConfigOnly`

2. **Audit Logging**
   - Log all file writes and shell commands
   - Structured JSON logs
   - Tamper-proof storage

3. **Team Collaboration**
   - Shared OWLEN.md across team
   - Project-scoped MCP servers in `.mcp.json`
   - Approval policy enforcement

**Implementation:**
- Extend `owlen-core::config` with managed settings
- New `owlen-core::audit` module
- Enterprise deployment documentation

## Testing Requirements

### Test Coverage Goals

- **Unit tests**: 80%+ coverage for `owlen-core`
- **Integration tests**: All MCP servers, providers
- **TUI tests**: Key workflows (not pixel-perfect)

### Test Organization

```rust
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::test_utils::MockProvider;
    use crate::mcp::test_utils::MockMcpClient;

    #[tokio::test]
    async fn test_feature() {
        // Setup (the request comes from a test fixture)
        let provider = MockProvider::new();

        // Execute
        let result = provider.chat(request).await;

        // Assert
        assert!(result.is_ok());
    }
}
```

### Running Tests

```bash
cargo test --all                 # All tests
cargo test --lib -p owlen-core   # Core library tests
cargo test --test integration    # Integration tests
```

## Documentation Standards

### Code Documentation

1. **Module-level** (`//!` at top of file):
   ```rust
   //! Brief module description
   //!
   //! Detailed explanation of module purpose,
   //! key types, and usage examples.
   ```

2. **Public APIs** (`///` above items):
   ```rust
   /// Brief description
   ///
   /// # Arguments
   /// * `arg1` - Description
   ///
   /// # Returns
   /// Description of return value
   ///
   /// # Errors
   /// When this function returns an error
   ///
   /// # Example
   /// ```
   /// let result = function(arg);
   /// ```
   pub fn function(arg: Type) -> Result<Output> {
       // implementation
   }
   ```

3. **Private items**: Optional, use for complex logic

### User Documentation

Location: `docs/` directory

Files to maintain:
- `architecture.md` - System design
- `configuration.md` - Config reference
- `migration-guide.md` - Version upgrades
- `troubleshooting.md` - Common issues
- `provider-implementation.md` - Adding providers
- `faq.md` - Frequently asked questions

## Git Workflow

### Branch Strategy

- `main` - stable releases only
- `dev` - active development (default)
- `feature/*` - new features
- `fix/*` - bug fixes
- `docs/*` - documentation only

### Commit Messages

Follow conventional commits:

```
type(scope): brief description

Detailed explanation of changes.

Breaking changes, if any.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
```

Types: `feat`, `fix`, `docs`, `refactor`, `test`, `chore`

### Pre-commit Hooks

Automatically run:
- `cargo fmt` (formatting)
- `cargo check` (compilation)
- `cargo clippy` (linting)
- YAML/TOML validation
- Trailing whitespace removal

## Performance Guidelines

### Optimization Priorities

1. **Startup time**: < 500ms cold start
2. **First token latency**: < 2s for local models
3. **Memory usage**: < 100MB base, < 500MB with conversation
4. **Responsiveness**: TUI redraws < 16ms (60 FPS)

### Profiling

```bash
cargo build --release --features profiling
valgrind --tool=callgrind target/release/owlen
kcachegrind callgrind.out.*
```

### Async Performance

- Avoid blocking in async contexts
- Use `tokio::spawn` for CPU-intensive work
- Set timeouts on all network operations
- Cancel tasks on shutdown

## Security Considerations

### Threat Model

**Trusted:**
- User's local machine
- User-installed Ollama models
- User configuration files

**Untrusted:**
- MCP server responses
- Web search results
- Code execution output
- Cloud LLM responses

### Security Measures

1. **Input Validation**
   - Sanitize all MCP tool arguments
   - Validate JSON schemas strictly
   - Escape shell commands

2. **Sandboxing**
   - Docker for code execution
   - Network isolation
   - Filesystem restrictions

3. **Secrets Management**
   - Never log API keys
   - Use environment variables
   - Encrypt sensitive config fields

4. **Dependency Auditing**
   ```bash
   cargo audit
   cargo deny check
   ```

## Debugging Tips

### Enable Debug Logging

```bash
OWLEN_DEBUG_OLLAMA=1 owlen   # Ollama requests
RUST_LOG=debug owlen         # All debug logs
RUST_BACKTRACE=1 owlen       # Stack traces
```

### Common Issues

1. **Timeout on Ollama**
   - Check `ollama ps` for loaded models
   - Increase timeout in config
   - Restart Ollama service

2. **MCP Server Not Found**
   - Verify `mcp_servers` config
   - Check server binary exists
   - Test server manually with STDIO

3. **TUI Rendering Issues**
   - Test in different terminals
   - Check terminal size (`tput cols; tput lines`)
   - Verify theme compatibility

## Contributing

### Before Submitting PR

1. Run full test suite: `cargo test --all`
2. Check formatting: `cargo fmt -- --check`
3. Run linter: `cargo clippy -- -D warnings`
4. Update documentation if API changed
5. Add tests for new features
6. Update CHANGELOG.md

### PR Description Template

```markdown
## Summary
Brief description of changes

## Type of Change
- [ ] Bug fix
- [ ] New feature
- [ ] Breaking change
- [ ] Documentation update

## Testing
Describe tests performed

## Checklist
- [ ] Tests added/updated
- [ ] Documentation updated
- [ ] CHANGELOG.md updated
- [ ] No clippy warnings
```

## Resources

### External Documentation

- [Ratatui Docs](https://ratatui.rs/)
- [Tokio Tutorial](https://tokio.rs/tokio/tutorial)
- [MCP Specification](https://modelcontextprotocol.io/)
- [Ollama API](https://github.com/ollama/ollama/blob/main/docs/api.md)

### Internal Documentation

- `.agents/new_phases.md` - 10-phase migration plan (completed)
- `docs/phase5-mode-system.md` - Mode system design
- `docs/migration-guide.md` - v0.x → v1.0 migration

### Community

- GitHub Issues: Bug reports and feature requests
- GitHub Discussions: Questions and ideas
- AUR Package: `owlen-git` (Arch Linux)

## Version History

- **v1.0.0** (current) - MCP-only architecture, Phase 10 complete
- **v0.2.0** - Added web search, code execution servers
- **v0.1.0** - Initial release with Ollama support

## License

Owlen is open source software. See LICENSE file for details.

---

**Last Updated**: 2025-10-11
**Maintained By**: Owlen Development Team
**For AI Agents**: Follow these guidelines when modifying Owlen codebase. Prioritize MCP client enhancement (Phase 11) and approval system (Phase 12) for feature parity with Codex/Claude Code while maintaining local-first philosophy.
CHANGELOG.md (114 changed lines)
@@ -1,114 +0,0 @@
|
||||
# Changelog
|
||||
|
||||
All notable changes to this project will be documented in this file.
|
||||
|
||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
- Comprehensive documentation suite including guides for architecture, configuration, testing, and more.
|
||||
- Rustdoc examples for core components like `Provider` and `SessionController`.
|
||||
- Module-level documentation for `owlen-tui`.
|
||||
- Provider integration tests (`crates/owlen-providers/tests`) covering registration, routing, and health status handling for the new `ProviderManager`.
|
||||
- TUI message and generation tests that exercise the non-blocking event loop, background worker, and message dispatch.
|
||||
- Ollama integration can now talk to Ollama Cloud when an API key is configured.
|
||||
- Ollama provider will also read `OLLAMA_API_KEY` / `OLLAMA_CLOUD_API_KEY` environment variables when no key is stored in the config.
|
||||
- `owlen config doctor`, `owlen config path`, and `owlen upgrade` CLI commands to automate migrations and surface manual update steps.
|
||||
- Startup provider health check with actionable hints when Ollama or remote MCP servers are unavailable.
|
||||
- `dev/check-windows.sh` helper script for on-demand Windows cross-checks.
|
||||
- Global F1 keybinding for the in-app help overlay and a clearer status hint on launch.
|
||||
- Automatic fallback to the new `ansi_basic` theme when the active terminal only advertises 16-color support.
|
||||
- Offline provider shim that keeps the TUI usable while primary providers are unreachable and communicates recovery steps inline.
|
||||
- `owlen cloud` subcommands (`setup`, `status`, `models`, `logout`) for managing Ollama Cloud credentials without hand-editing config files.
|
||||
- Tabbed model selector that separates local and cloud providers, including cloud indicators in the UI.
|
||||
- Footer status line includes provider connectivity/credential summaries (e.g., cloud auth failures, missing API keys).
|
||||
- Secure credential vault integration for Ollama Cloud API keys when `privacy.encrypt_local_data = true`.
|
||||
- Input panel respects a new `ui.input_max_rows` setting so long prompts expand predictably before scrolling kicks in.
|
||||
- Command palette offers fuzzy `:model` filtering and `:provider` completions for fast switching.
|
||||
- Message rendering caches wrapped lines and throttles streaming redraws to keep the TUI responsive on long sessions.
|
||||
- Model picker badges now inspect provider capabilities so vision/audio/thinking models surface the correct icons even when descriptions are sparse.
|
||||
- Chat history honors `ui.scrollback_lines`, trimming older rows to keep the TUI responsive and surfacing a "↓ New messages" badge whenever updates land off-screen.
|
||||
|
||||
### Changed
|
||||
- The main `README.md` has been updated to be more concise and link to the new documentation.
|
||||
- Default configuration now pre-populates both `providers.ollama` and `providers.ollama-cloud` entries so switching between local and cloud backends is a single setting change.
|
||||
- `McpMode` support was restored with explicit validation; `remote_only`, `remote_preferred`, and `local_only` now behave predictably.
|
||||
- Configuration loading performs structural validation and fails fast on missing default providers or invalid MCP definitions.
|
||||
- Ollama provider error handling now distinguishes timeouts, missing models, and authentication failures.
|
||||
- `owlen` warns when the active terminal likely lacks 256-color support.
|
||||
- `config.toml` now carries a schema version (`1.2.0`) and is migrated automatically; deprecated keys such as `agent.max_tool_calls` trigger warnings instead of hard failures.
|
||||
- Model selector navigation (Tab/Shift-Tab) now switches between local and cloud tabs while preserving selection state.
|
||||
- Header displays the active model together with its provider (e.g., `Model (Provider)`), improving clarity when swapping backends.
|
||||
- Documentation refreshed to cover the message handler architecture, the background health worker, multi-provider configuration, and the new provider onboarding checklist.
|
||||
|
||||
---
|
||||
|
||||
## [0.1.11] - 2025-10-18
|
||||
|
||||
### Changed
|
||||
- Bump workspace packages and distribution metadata to version `0.1.11`.
|
||||
|
||||
## [0.1.10] - 2025-10-03
|
||||
|
||||
### Added
|
||||
- **Material Light Theme**: A new built-in theme, `material-light`, has been added.
|
||||
|
||||
### Fixed
|
||||
- **UI Readability**: Fixed a bug causing unreadable text in light themes.
|
||||
- **Visual Selection**: The visual selection mode now correctly colors unselected text portions.
|
||||
|
||||
### Changed
|
||||
- **Theme Colors**: The color palettes for `gruvbox`, `rose-pine`, and `monokai` have been corrected.
|
||||
- **In-App Help**: The `:help` menu has been significantly expanded and updated.
|
||||
|
||||
## [0.1.9] - 2025-10-03
|
||||
|
||||
*This version corresponds to the release tagged v0.1.10 in the source repository.*
|
||||
|
||||
### Added
|
||||
- **Material Light Theme**: A new built-in theme, `material-light`, has been added.
|
||||
|
||||
### Fixed
|
||||
- **UI Readability**: Fixed a bug causing unreadable text in light themes.
|
||||
- **Visual Selection**: The visual selection mode now correctly colors unselected text portions.
|
||||
|
||||
### Changed
|
||||
- **Theme Colors**: The color palettes for `gruvbox`, `rose-pine`, and `monokai` have been corrected.
|
||||
- **In-App Help**: The `:help` menu has been significantly expanded and updated.
|
||||
|
||||
## [0.1.8] - 2025-10-02
|
||||
|
||||
### Added
|
||||
- **Command Autocompletion**: Implemented intelligent command suggestions and Tab completion in command mode.
|
||||
|
||||
### Changed
|
||||
- **Build & CI**: Fixed cross-compilation for ARM64, ARMv7, and Windows.
|
||||
|
||||
## [0.1.7] - 2025-10-02
|
||||
|
||||
### Added
|
||||
- **Tabbed Help System**: The help menu is now organized into five tabs for easier navigation.
|
||||
- **Command Aliases**: Added `:o` as a short alias for `:load` / `:open`.
|
||||
|
||||
### Changed
|
||||
- **Session Management**: Improved AI-generated session descriptions.
|
||||
|
||||
## [0.1.6] - 2025-10-02
|
||||
|
||||
### Added
|
||||
- **Platform-Specific Storage**: Sessions are now saved to platform-appropriate directories (e.g., `~/.local/share/owlen` on Linux).
|
||||
- **AI-Generated Session Descriptions**: Conversations can be automatically summarized on save.
|
||||
|
||||
### Changed
|
||||
- **Migration**: Users on older versions can manually move their sessions from `~/.config/owlen/sessions` to the new platform-specific directory.
|
||||
|
||||
## [0.1.4] - 2025-10-01
|
||||
|
||||
### Added
|
||||
- **Multi-Platform Builds**: Pre-built binaries are now provided for Linux (x86_64, aarch64, armv7) and Windows (x86_64).
|
||||
- **AUR Package**: Owlen is now available on the Arch User Repository.
|
||||
|
||||
### Changed
|
||||
- **Build System**: Switched from OpenSSL to rustls for better cross-platform compatibility.
|
||||
@@ -1,121 +0,0 @@
|
||||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
We as members, contributors, and leaders pledge to make participation in our
|
||||
community a harassment-free experience for everyone, regardless of age, body
|
||||
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
||||
identity and expression, level of experience, education, socio-economic status,
|
||||
nationality, personal appearance, race, religion, or sexual identity
|
||||
and orientation.
|
||||
|
||||
We pledge to act and interact in ways that are welcoming, open, and respectful.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to a positive environment for our
|
||||
community include:
|
||||
|
||||
* Demonstrating empathy and kindness toward other people
|
||||
* Being respectful of differing opinions, viewpoints, and experiences
|
||||
* Giving and gracefully accepting constructive feedback
|
||||
* Accepting responsibility and apologizing to those affected by our mistakes,
|
||||
and learning from the experience
|
||||
* Focusing on what is best not just for us as individuals, but for the
|
||||
overall community
|
||||
|
||||
Examples of unacceptable behavior include:
|
||||
|
||||
* The use of sexualized language or imagery, and sexual attention or
|
||||
advances of any kind
|
||||
* Trolling, insulting or derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or email
|
||||
address, without their explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Enforcement Responsibilities
|
||||
|
||||
Community leaders are responsible for clarifying and enforcing our standards of
|
||||
acceptable behavior and will take appropriate and fair corrective action in
|
||||
response to any behavior that they deem inappropriate, threatening, offensive,
|
||||
or harmful.
|
||||
|
||||
Community leaders have the right and responsibility to remove, edit, or reject
|
||||
comments, commits, code, wiki edits, issues, and other contributions that are
|
||||
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
||||
decisions when appropriate.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies within all community spaces, and also applies when
|
||||
an individual is officially representing the community in public spaces.
|
||||
Examples of representing our community include using an official e-mail address,
|
||||
posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported to the community leaders responsible for enforcement at
|
||||
[security@owlibou.com](mailto:security@owlibou.com). All complaints will be
|
||||
reviewed and investigated promptly and fairly.
|
||||
|
||||
All community leaders are obligated to respect the privacy and security of the
|
||||
reporter of any incident.
|
||||
|
||||
## Enforcement Guidelines
|
||||
|
||||
Community leaders will follow these Community Impact Guidelines in determining
|
||||
the consequences for any action they deem in violation of this Code of Conduct:
|
||||
|
||||
### 1. Correction
|
||||
|
||||
**Community Impact**: Use of inappropriate language or other behavior deemed
|
||||
unprofessional or unwelcome in the community.
|
||||
|
||||
**Consequence**: A private, written warning from community leaders, providing
|
||||
clarity around the nature of the violation and an explanation of why the
|
||||
behavior was inappropriate. A public apology may be requested.
|
||||
|
||||
### 2. Warning
|
||||
|
||||
**Community Impact**: A violation through a single incident or series
|
||||
of actions.
|
||||
|
||||
**Consequence**: A warning with consequences for continued behavior. No
|
||||
interaction with the people involved, including unsolicited interaction with
|
||||
those enforcing the Code of Conduct, for a specified period of time. This
|
||||
includes avoiding interaction in community spaces as well as external channels
|
||||
like social media. Violating these terms may lead to a temporary or
|
||||
permanent ban.
|
||||
|
||||
### 3. Temporary Ban
|
||||
|
||||
**Community Impact**: A serious violation of community standards, including
|
||||
sustained inappropriate behavior.
|
||||
|
||||
**Consequence**: A temporary ban from any sort of interaction or public
|
||||
communication with the community for a specified period of time. No public or
|
||||
private interaction with the people involved, including unsolicited interaction
|
||||
with those enforcing the Code of Conduct, is allowed during this period.
|
||||
Violating these terms may lead to a permanent ban.
|
||||
|
||||
### 4. Permanent Ban
|
||||
|
||||
**Community Impact**: Demonstrating a pattern of violation of community
|
||||
standards, including sustained inappropriate behavior, harassment of an
|
||||
individual, or aggression toward or disparagement of classes of individuals.
|
||||
|
||||
**Consequence**: A permanent ban from any sort of public interaction within
|
||||
the community.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
||||
version 2.1, available at
|
||||
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
|
||||
CONTRIBUTING.md (126 changed lines)
@@ -1,126 +0,0 @@
|
||||
# Contributing to Owlen
|
||||
|
||||
First off, thank you for considering contributing to Owlen! It's people like you that make Owlen such a great tool.
|
||||
|
||||
Following these guidelines helps to communicate that you respect the time of the developers managing and developing this open source project. In return, they should reciprocate that respect in addressing your issue, assessing changes, and helping you finalize your pull requests.
|
||||
|
||||
## Code of Conduct
|
||||
|
||||
This project and everyone participating in it is governed by the [Owlen Code of Conduct](CODE_OF_CONDUCT.md). By participating, you are expected to uphold this code. Please report unacceptable behavior.
|
||||
|
||||
## How Can I Contribute?
|
||||
|
||||
### Repository map
|
||||
|
||||
Need a quick orientation before diving in? Start with the curated [repo map](docs/repo-map.md) for a two-level directory overview. If you move folders around, regenerate it with `scripts/gen-repo-map.sh`.
|
||||
|
||||
### Reporting Bugs
|
||||
|
||||
This is one of the most helpful ways you can contribute. Before creating a bug report, please check a few things:
|
||||
|
||||
1. **Check the [troubleshooting guide](docs/troubleshooting.md).** Your issue might be a common one with a known solution.
|
||||
2. **Search the existing issues.** It's possible someone has already reported the same bug. If so, add a comment to the existing issue instead of creating a new one.

When creating a bug report, please include as many details as possible. Fill out the required template; the information it asks for helps us resolve issues faster.

### Suggesting Enhancements

If you have an idea for a new feature or an improvement to an existing one, we'd love to hear about it. Please provide as much context as you can about what you're trying to achieve.

### Your First Code Contribution

Unsure where to begin contributing to Owlen? You can start by looking through the `good first issue` and `help wanted` issues.

### Pull Requests

The process for submitting a pull request is as follows:

1. **Fork the repository** and create your branch from `main`.
2. **Set up pre-commit hooks** (see [Development Setup](#development-setup) below). This will automatically format and lint your code.
3. **Make your changes.**
4. **Run the tests.**
   - `cargo test --all`
5. **Commit your changes.** The pre-commit hooks will automatically run `cargo fmt`, `cargo check`, and `cargo clippy`. If you need to bypass the hooks (not recommended), use `git commit --no-verify`.
6. **Write a clear, concise commit message.** We follow the [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) specification.
7. **Push to your fork** and submit a pull request to Owlen's `main` branch.
8. **Include a clear description** of the problem and solution. Include the relevant issue number if applicable.
9. **Declare AI assistance.** If any part of the patch was generated with an AI tool (e.g., ChatGPT, Claude Code), call that out in the PR description. A human maintainer must review and approve AI-assisted changes before merge.

## Development Setup

To get started with the codebase, you'll need Rust installed. Then clone the repository and build the project:

```sh
git clone https://github.com/Owlibou/owlen.git
cd owlen
cargo build
```

### Pre-commit Hooks

We use [pre-commit](https://pre-commit.com/) to automatically run formatting and linting checks before each commit. This helps maintain code quality and consistency.

**Install pre-commit:**

```sh
# Arch Linux
sudo pacman -S pre-commit

# Other Linux/macOS
pip install pre-commit

# Verify installation
pre-commit --version
```

**Set up the hooks:**

```sh
cd owlen
pre-commit install
```

Once installed, the hooks will automatically run on every commit. You can also run them manually:

```sh
# Run on all files
pre-commit run --all-files

# Run on staged files only
pre-commit run
```

The pre-commit hooks will check:
- Code formatting (`cargo fmt`)
- Compilation (`cargo check`)
- Linting (`cargo clippy --all-features`)
- General file hygiene (trailing whitespace, EOF newlines, etc.)

## Coding Style

- We use `cargo fmt` for automated code formatting. Please run it before committing your changes.
- We use `cargo clippy` for linting. Your code should be free of any clippy warnings.

## Commit Message Conventions

We use [Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/) for our commit messages. This allows for automated changelog generation and makes the project history easier to read.

The basic format is:

```
<type>[optional scope]: <description>

[optional body]

[optional footer(s)]
```

**Types:** `feat`, `fix`, `docs`, `style`, `refactor`, `test`, `chore`, `build`, `ci`.

**Example:**

```
feat(provider): add support for Gemini Pro
```

Thank you for your contribution!
103 Cargo.toml
@@ -1,86 +1,31 @@
[workspace]
resolver = "2"
members = [
    "crates/owlen-core",
    "crates/owlen-tui",
    "crates/owlen-cli",
    "crates/owlen-providers",
    "crates/mcp/server",
    "crates/mcp/llm-server",
    "crates/mcp/client",
    "crates/mcp/code-server",
    "crates/mcp/prompt-server",
    "crates/owlen-markdown",
    "xtask",
    "crates/app/cli",
    "crates/app/ui",
    "crates/core/agent",
    "crates/llm/core",
    "crates/llm/anthropic",
    "crates/llm/ollama",
    "crates/llm/openai",
    "crates/platform/config",
    "crates/platform/hooks",
    "crates/platform/permissions",
    "crates/platform/plugins",
    "crates/tools/ask",
    "crates/tools/bash",
    "crates/tools/fs",
    "crates/tools/notebook",
    "crates/tools/plan",
    "crates/tools/skill",
    "crates/tools/slash",
    "crates/tools/task",
    "crates/tools/todo",
    "crates/tools/web",
    "crates/integration/mcp-client",
]
exclude = []
resolver = "2"

[workspace.package]
version = "0.1.11"
edition = "2024"
authors = ["Owlibou"]
license = "AGPL-3.0"
repository = "https://somegit.dev/Owlibou/owlen"
homepage = "https://somegit.dev/Owlibou/owlen"
keywords = ["llm", "tui", "cli", "ollama", "chat"]
categories = ["command-line-utilities"]

[workspace.dependencies]
# Async runtime and utilities
tokio = { version = "1.0", features = ["full"] }
tokio-stream = "0.1"
tokio-util = { version = "0.7", features = ["rt"] }
futures = "0.3"
futures-util = "0.3"

# TUI framework
ratatui = "0.28"
crossterm = "0.28"
tui-textarea = "0.6"

# HTTP client and JSON handling
reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "rustls-tls"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = { version = "1.0" }

# Utilities
uuid = { version = "1.0", features = ["v4", "serde"] }
anyhow = "1.0"
thiserror = "2.0"
nix = "0.29"
which = "6.0"
tempfile = "3.8"
jsonschema = "0.17"
aes-gcm = "0.10"
ring = "0.17"
keyring = "3.0"
chrono = { version = "0.4", features = ["serde"] }
urlencoding = "2.1"
regex = "1.10"
rpassword = "7.3"
sqlx = { version = "0.7", default-features = false, features = ["runtime-tokio-rustls", "sqlite", "macros", "uuid", "chrono", "migrate"] }
log = "0.4"
dirs = "5.0"
serde_yaml = "0.9"
handlebars = "6.0"

# Configuration
toml = "0.8"
shellexpand = "3.1"

# Database
sled = "0.34"

# For better text handling
textwrap = "0.16"

# Async traits
async-trait = "0.1"

# CLI framework
clap = { version = "4.0", features = ["derive"] }

# Dev dependencies
tokio-test = "0.4"

# For more keys and their definitions, see https://doc.rust-lang.org/cargo/reference/manifest.html
rust-version = "1.91"
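Individual member crates inherit this shared metadata through `workspace = true` keys (see, for example, the `edition.workspace = true` lines in `crates/app/cli/Cargo.toml` further down in this diff). As a rough sketch of what a new member manifest could look like (the crate name and the dependency selection below are illustrative, not taken from the repository):

```toml
[package]
name = "owlen-example-tool"   # hypothetical crate name
version.workspace = true
edition.workspace = true
license.workspace = true
repository.workspace = true

[dependencies]
# Versions are resolved from [workspace.dependencies] in the root manifest
tokio = { workspace = true }
serde = { workspace = true }
anyhow = { workspace = true }
```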
661 LICENSE
@@ -1,661 +0,0 @@
GNU AFFERO GENERAL PUBLIC LICENSE
|
||||
Version 3, 19 November 2007
|
||||
|
||||
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
|
||||
Everyone is permitted to copy and distribute verbatim copies
|
||||
of this license document, but changing it is not allowed.
|
||||
|
||||
Preamble
|
||||
|
||||
The GNU Affero General Public License is a free, copyleft license for
|
||||
software and other kinds of works, specifically designed to ensure
|
||||
cooperation with the community in the case of network server software.
|
||||
|
||||
The licenses for most software and other practical works are designed
|
||||
to take away your freedom to share and change the works. By contrast,
|
||||
our General Public Licenses are intended to guarantee your freedom to
|
||||
share and change all versions of a program--to make sure it remains free
|
||||
software for all its users.
|
||||
|
||||
When we speak of free software, we are referring to freedom, not
|
||||
price. Our General Public Licenses are designed to make sure that you
|
||||
have the freedom to distribute copies of free software (and charge for
|
||||
them if you wish), that you receive source code or can get it if you
|
||||
want it, that you can change the software or use pieces of it in new
|
||||
free programs, and that you know you can do these things.
|
||||
|
||||
Developers that use our General Public Licenses protect your rights
|
||||
with two steps: (1) assert copyright on the software, and (2) offer
|
||||
you this License which gives you legal permission to copy, distribute
|
||||
and/or modify the software.
|
||||
|
||||
A secondary benefit of defending all users' freedom is that
|
||||
improvements made in alternate versions of the program, if they
|
||||
receive widespread use, become available for other developers to
|
||||
incorporate. Many developers of free software are heartened and
|
||||
encouraged by the resulting cooperation. However, in the case of
|
||||
software used on network servers, this result may fail to come about.
|
||||
The GNU General Public License permits making a modified version and
|
||||
letting the public access it on a server without ever releasing its
|
||||
source code to the public.
|
||||
|
||||
The GNU Affero General Public License is designed specifically to
|
||||
ensure that, in such cases, the modified source code becomes available
|
||||
to the community. It requires the operator of a network server to
|
||||
provide the source code of the modified version running there to the
|
||||
users of that server. Therefore, public use of a modified version, on
|
||||
a publicly accessible server, gives the public access to the source
|
||||
code of the modified version.
|
||||
|
||||
An older license, called the Affero General Public License and
|
||||
published by Affero, was designed to accomplish similar goals. This is
|
||||
a different license, not a version of the Affero GPL, but Affero has
|
||||
released a new version of the Affero GPL which permits relicensing under
|
||||
this license.
|
||||
|
||||
The precise terms and conditions for copying, distribution and
|
||||
modification follow.
|
||||
|
||||
TERMS AND CONDITIONS
|
||||
|
||||
0. Definitions.
|
||||
|
||||
"This License" refers to version 3 of the GNU Affero General Public License.
|
||||
|
||||
"Copyright" also means copyright-like laws that apply to other kinds of
|
||||
works, such as semiconductor masks.
|
||||
|
||||
"The Program" refers to any copyrightable work licensed under this
|
||||
License. Each licensee is addressed as "you". "Licensees" and
|
||||
"recipients" may be individuals or organizations.
|
||||
|
||||
To "modify" a work means to copy from or adapt all or part of the work
|
||||
in a fashion requiring copyright permission, other than the making of an
|
||||
exact copy. The resulting work is called a "modified version" of the
|
||||
earlier work or a work "based on" the earlier work.
|
||||
|
||||
A "covered work" means either the unmodified Program or a work based
|
||||
on the Program.
|
||||
|
||||
To "propagate" a work means to do anything with it that, without
|
||||
permission, would make you directly or secondarily liable for
|
||||
infringement under applicable copyright law, except executing it on a
|
||||
computer or modifying a private copy. Propagation includes copying,
|
||||
distribution (with or without modification), making available to the
|
||||
public, and in some countries other activities as well.
|
||||
|
||||
To "convey" a work means any kind of propagation that enables other
|
||||
parties to make or receive copies. Mere interaction with a user through
|
||||
a computer network, with no transfer of a copy, is not conveying.
|
||||
|
||||
An interactive user interface displays "Appropriate Legal Notices"
|
||||
to the extent that it includes a convenient and prominently visible
|
||||
feature that (1) displays an appropriate copyright notice, and (2)
|
||||
tells the user that there is no warranty for the work (except to the
|
||||
extent that warranties are provided), that licensees may convey the
|
||||
work under this License, and how to view a copy of this License. If
|
||||
the interface presents a list of user commands or options, such as a
|
||||
menu, a prominent item in the list meets this criterion.
|
||||
|
||||
1. Source Code.
|
||||
|
||||
The "source code" for a work means the preferred form of the work
|
||||
for making modifications to it. "Object code" means any non-source
|
||||
form of a work.
|
||||
|
||||
A "Standard Interface" means an interface that either is an official
|
||||
standard defined by a recognized standards body, or, in the case of
|
||||
interfaces specified for a particular programming language, one that
|
||||
is widely used among developers working in that language.
|
||||
|
||||
The "System Libraries" of an executable work include anything, other
|
||||
than the work as a whole, that (a) is included in the normal form of
|
||||
packaging a Major Component, but which is not part of that Major
|
||||
Component, and (b) serves only to enable use of the work with that
|
||||
Major Component, or to implement a Standard Interface for which an
|
||||
implementation is available to the public in source code form. A
|
||||
"Major Component", in this context, means a major essential component
|
||||
(kernel, window system, and so on) of the specific operating system
|
||||
(if any) on which the executable work runs, or a compiler used to
|
||||
produce the work, or an object code interpreter used to run it.
|
||||
|
||||
The "Corresponding Source" for a work in object code form means all
|
||||
the source code needed to generate, install, and (for an executable
|
||||
work) run the object code and to modify the work, including scripts to
|
||||
control those activities. However, it does not include the work's
|
||||
System Libraries, or general-purpose tools or generally available free
|
||||
programs which are used unmodified in performing those activities but
|
||||
which are not part of the work. For example, Corresponding Source
|
||||
includes interface definition files associated with source files for
|
||||
the work, and the source code for shared libraries and dynamically
|
||||
linked subprograms that the work is specifically designed to require,
|
||||
such as by intimate data communication or control flow between those
|
||||
subprograms and other parts of the work.
|
||||
|
||||
The Corresponding Source need not include anything that users
|
||||
can regenerate automatically from other parts of the Corresponding
|
||||
Source.
|
||||
|
||||
The Corresponding Source for a work in source code form is that
|
||||
same work.
|
||||
|
||||
2. Basic Permissions.
|
||||
|
||||
All rights granted under this License are granted for the term of
|
||||
copyright on the Program, and are irrevocable provided the stated
|
||||
conditions are met. This License explicitly affirms your unlimited
|
||||
permission to run the unmodified Program. The output from running a
|
||||
covered work is covered by this License only if the output, given its
|
||||
content, constitutes a covered work. This License acknowledges your
|
||||
rights of fair use or other equivalent, as provided by copyright law.
|
||||
|
||||
You may make, run and propagate covered works that you do not
|
||||
convey, without conditions so long as your license otherwise remains
|
||||
in force. You may convey covered works to others for the sole purpose
|
||||
of having them make modifications exclusively for you, or provide you
|
||||
with facilities for running those works, provided that you comply with
|
||||
the terms of this License in conveying all material for which you do
|
||||
not control copyright. Those thus making or running the covered works
|
||||
for you must do so exclusively on your behalf, under your direction
|
||||
and control, on terms that prohibit them from making any copies of
|
||||
your copyrighted material outside their relationship with you.
|
||||
|
||||
Conveying under any other circumstances is permitted solely under
|
||||
the conditions stated below. Sublicensing is not allowed; section 10
|
||||
makes it unnecessary.
|
||||
|
||||
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
|
||||
|
||||
No covered work shall be deemed part of an effective technological
|
||||
measure under any applicable law fulfilling obligations under article
|
||||
11 of the WIPO copyright treaty adopted on 20 December 1996, or
|
||||
similar laws prohibiting or restricting circumvention of such
|
||||
measures.
|
||||
|
||||
When you convey a covered work, you waive any legal power to forbid
|
||||
circumvention of technological measures to the extent such circumvention
|
||||
is effected by exercising rights under this License with respect to
|
||||
the covered work, and you disclaim any intention to limit operation or
|
||||
modification of the work as a means of enforcing, against the work's
|
||||
users, your or third parties' legal rights to forbid circumvention of
|
||||
technological measures.
|
||||
|
||||
4. Conveying Verbatim Copies.
|
||||
|
||||
You may convey verbatim copies of the Program's source code as you
|
||||
receive it, in any medium, provided that you conspicuously and
|
||||
appropriately publish on each copy an appropriate copyright notice;
|
||||
keep intact all notices stating that this License and any
|
||||
non-permissive terms added in accord with section 7 apply to the code;
|
||||
keep intact all notices of the absence of any warranty; and give all
|
||||
recipients a copy of this License along with the Program.
|
||||
|
||||
You may charge any price or no price for each copy that you convey,
|
||||
and you may offer support or warranty protection for a fee.
|
||||
|
||||
5. Conveying Modified Source Versions.
|
||||
|
||||
You may convey a work based on the Program, or the modifications to
|
||||
produce it from the Program, in the form of source code under the
|
||||
terms of section 4, provided that you also meet all of these conditions:
|
||||
|
||||
a) The work must carry prominent notices stating that you modified
|
||||
it, and giving a relevant date.
|
||||
|
||||
b) The work must carry prominent notices stating that it is
|
||||
released under this License and any conditions added under section
|
||||
7. This requirement modifies the requirement in section 4 to
|
||||
"keep intact all notices".
|
||||
|
||||
c) You must license the entire work, as a whole, under this
|
||||
License to anyone who comes into possession of a copy. This
|
||||
License will therefore apply, along with any applicable section 7
|
||||
additional terms, to the whole of the work, and all its parts,
|
||||
regardless of how they are packaged. This License gives no
|
||||
permission to license the work in any other way, but it does not
|
||||
invalidate such permission if you have separately received it.
|
||||
|
||||
d) If the work has interactive user interfaces, each must display
|
||||
Appropriate Legal Notices; however, if the Program has interactive
|
||||
interfaces that do not display Appropriate Legal Notices, your
|
||||
work need not make them do so.
|
||||
|
||||
A compilation of a covered work with other separate and independent
|
||||
works, which are not by their nature extensions of the covered work,
|
||||
and which are not combined with it such as to form a larger program,
|
||||
in or on a volume of a storage or distribution medium, is called an
|
||||
"aggregate" if the compilation and its resulting copyright are not
|
||||
used to limit the access or legal rights of the compilation's users
|
||||
beyond what the individual works permit. Inclusion of a covered work
|
||||
in an aggregate does not cause this License to apply to the other
|
||||
parts of the aggregate.
|
||||
|
||||
6. Conveying Non-Source Forms.
|
||||
|
||||
You may convey a covered work in object code form under the terms
|
||||
of sections 4 and 5, provided that you also convey the
|
||||
machine-readable Corresponding Source under the terms of this License,
|
||||
in one of these ways:
|
||||
|
||||
a) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by the
|
||||
Corresponding Source fixed on a durable physical medium
|
||||
customarily used for software interchange.
|
||||
|
||||
b) Convey the object code in, or embodied in, a physical product
|
||||
(including a physical distribution medium), accompanied by a
|
||||
written offer, valid for at least three years and valid for as
|
||||
long as you offer spare parts or customer support for that product
|
||||
model, to give anyone who possesses the object code either (1) a
|
||||
copy of the Corresponding Source for all the software in the
|
||||
product that is covered by this License, on a durable physical
|
||||
medium customarily used for software interchange, for a price no
|
||||
more than your reasonable cost of physically performing this
|
||||
conveying of source, or (2) access to copy the
|
||||
Corresponding Source from a network server at no charge.
|
||||
|
||||
c) Convey individual copies of the object code with a copy of the
|
||||
written offer to provide the Corresponding Source. This
|
||||
alternative is allowed only occasionally and noncommercially, and
|
||||
only if you received the object code with such an offer, in accord
|
||||
with subsection 6b.
|
||||
|
||||
d) Convey the object code by offering access from a designated
|
||||
place (gratis or for a charge), and offer equivalent access to the
|
||||
Corresponding Source in the same way through the same place at no
|
||||
further charge. You need not require recipients to copy the
|
||||
Corresponding Source along with the object code. If the place to
|
||||
copy the object code is a network server, the Corresponding Source
|
||||
may be on a different server (operated by you or a third party)
|
||||
that supports equivalent copying facilities, provided you maintain
|
||||
clear directions next to the object code saying where to find the
|
||||
Corresponding Source. Regardless of what server hosts the
|
||||
Corresponding Source, you remain obligated to ensure that it is
|
||||
available for as long as needed to satisfy these requirements.
|
||||
|
||||
e) Convey the object code using peer-to-peer transmission, provided
|
||||
you inform other peers where the object code and Corresponding
|
||||
Source of the work are being offered to the general public at no
|
||||
charge under subsection 6d.
|
||||
|
||||
A separable portion of the object code, whose source code is excluded
|
||||
from the Corresponding Source as a System Library, need not be
|
||||
included in conveying the object code work.
|
||||
|
||||
A "User Product" is either (1) a "consumer product", which means any
|
||||
tangible personal property which is normally used for personal, family,
|
||||
or household purposes, or (2) anything designed or sold for incorporation
|
||||
into a dwelling. In determining whether a product is a consumer product,
|
||||
doubtful cases shall be resolved in favor of coverage. For a particular
|
||||
product received by a particular user, "normally used" refers to a
|
||||
typical or common use of that class of product, regardless of the status
|
||||
of the particular user or of the way in which the particular user
|
||||
actually uses, or expects or is expected to use, the product. A product
|
||||
is a consumer product regardless of whether the product has substantial
|
||||
commercial, industrial or non-consumer uses, unless such uses represent
|
||||
the only significant mode of use of the product.
|
||||
|
||||
"Installation Information" for a User Product means any methods,
|
||||
procedures, authorization keys, or other information required to install
|
||||
and execute modified versions of a covered work in that User Product from
|
||||
a modified version of its Corresponding Source. The information must
|
||||
suffice to ensure that the continued functioning of the modified object
|
||||
code is in no case prevented or interfered with solely because
|
||||
modification has been made.
|
||||
|
||||
If you convey an object code work under this section in, or with, or
|
||||
specifically for use in, a User Product, and the conveying occurs as
|
||||
part of a transaction in which the right of possession and use of the
|
||||
User Product is transferred to the recipient in perpetuity or for a
|
||||
fixed term (regardless of how the transaction is characterized), the
|
||||
Corresponding Source conveyed under this section must be accompanied
|
||||
by the Installation Information. But this requirement does not apply
|
||||
if neither you nor any third party retains the ability to install
|
||||
modified object code on the User Product (for example, the work has
|
||||
been installed in ROM).
|
||||
|
||||
The requirement to provide Installation Information does not include a
|
||||
requirement to continue to provide support service, warranty, or updates
|
||||
for a work that has been modified or installed by the recipient, or for
|
||||
the User Product in which it has been modified or installed. Access to a
|
||||
network may be denied when the modification itself materially and
|
||||
adversely affects the operation of the network or violates the rules and
|
||||
protocols for communication across the network.
|
||||
|
||||
Corresponding Source conveyed, and Installation Information provided,
|
||||
in accord with this section must be in a format that is publicly
|
||||
documented (and with an implementation available to the public in
|
||||
source code form), and must require no special password or key for
|
||||
unpacking, reading or copying.
|
||||
|
||||
7. Additional Terms.
|
||||
|
||||
"Additional permissions" are terms that supplement the terms of this
|
||||
License by making exceptions from one or more of its conditions.
|
||||
Additional permissions that are applicable to the entire Program shall
|
||||
be treated as though they were included in this License, to the extent
|
||||
that they are valid under applicable law. If additional permissions
|
||||
apply only to part of the Program, that part may be used separately
|
||||
under those permissions, but the entire Program remains governed by
|
||||
this License without regard to the additional permissions.
|
||||
|
||||
When you convey a copy of a covered work, you may at your option
|
||||
remove any additional permissions from that copy, or from any part of
|
||||
it. (Additional permissions may be written to require their own
|
||||
removal in certain cases when you modify the work.) You may place
|
||||
additional permissions on material, added by you to a covered work,
|
||||
for which you have or can give appropriate copyright permission.
|
||||
|
||||
Notwithstanding any other provision of this License, for material you
|
||||
add to a covered work, you may (if authorized by the copyright holders of
|
||||
that material) supplement the terms of this License with terms:
|
||||
|
||||
a) Disclaiming warranty or limiting liability differently from the
|
||||
terms of sections 15 and 16 of this License; or
|
||||
|
||||
b) Requiring preservation of specified reasonable legal notices or
|
||||
author attributions in that material or in the Appropriate Legal
|
||||
Notices displayed by works containing it; or
|
||||
|
||||
c) Prohibiting misrepresentation of the origin of that material, or
|
||||
requiring that modified versions of such material be marked in
|
||||
reasonable ways as different from the original version; or
|
||||
|
||||
d) Limiting the use for publicity purposes of names of licensors or
|
||||
authors of the material; or
|
||||
|
||||
e) Declining to grant rights under trademark law for use of some
|
||||
trade names, trademarks, or service marks; or
|
||||
|
||||
f) Requiring indemnification of licensors and authors of that
|
||||
material by anyone who conveys the material (or modified versions of
|
||||
it) with contractual assumptions of liability to the recipient, for
|
||||
any liability that these contractual assumptions directly impose on
|
||||
those licensors and authors.
|
||||
|
||||
All other non-permissive additional terms are considered "further
|
||||
restrictions" within the meaning of section 10. If the Program as you
|
||||
received it, or any part of it, contains a notice stating that it is
|
||||
governed by this License along with a term that is a further
|
||||
restriction, you may remove that term. If a license document contains
|
||||
a further restriction but permits relicensing or conveying under this
|
||||
License, you may add to a covered work material governed by the terms
|
||||
of that license document, provided that the further restriction does
|
||||
not survive such relicensing or conveying.
|
||||
|
||||
If you add terms to a covered work in accord with this section, you
|
||||
must place, in the relevant source files, a statement of the
|
||||
additional terms that apply to those files, or a notice indicating
|
||||
where to find the applicable terms.
|
||||
|
||||
Additional terms, permissive or non-permissive, may be stated in the
|
||||
form of a separately written license, or stated as exceptions;
|
||||
the above requirements apply either way.
|
||||
|
||||
8. Termination.
|
||||
|
||||
You may not propagate or modify a covered work except as expressly
|
||||
provided under this License. Any attempt otherwise to propagate or
|
||||
modify it is void, and will automatically terminate your rights under
|
||||
this License (including any patent licenses granted under the third
|
||||
paragraph of section 11).
|
||||
|
||||
However, if you cease all violation of this License, then your
|
||||
license from a particular copyright holder is reinstated (a)
|
||||
provisionally, unless and until the copyright holder explicitly and
|
||||
finally terminates your license, and (b) permanently, if the copyright
|
||||
holder fails to notify you of the violation by some reasonable means
|
||||
prior to 60 days after the cessation.
|
||||
|
||||
Moreover, your license from a particular copyright holder is
|
||||
reinstated permanently if the copyright holder notifies you of the
|
||||
violation by some reasonable means, this is the first time you have
|
||||
received notice of violation of this License (for any work) from that
|
||||
copyright holder, and you cure the violation prior to 30 days after
|
||||
your receipt of the notice.
|
||||
|
||||
Termination of your rights under this section does not terminate the
|
||||
licenses of parties who have received copies or rights from you under
|
||||
this License. If your rights have been terminated and not permanently
|
||||
reinstated, you do not qualify to receive new licenses for the same
|
||||
material under section 10.
|
||||
|
||||
9. Acceptance Not Required for Having Copies.
|
||||
|
||||
You are not required to accept this License in order to receive or
|
||||
run a copy of the Program. Ancillary propagation of a covered work
|
||||
occurring solely as a consequence of using peer-to-peer transmission
|
||||
to receive a copy likewise does not require acceptance. However,
|
||||
nothing other than this License grants you permission to propagate or
|
||||
modify any covered work. These actions infringe copyright if you do
|
||||
not accept this License. Therefore, by modifying or propagating a
|
||||
covered work, you indicate your acceptance of this License to do so.
|
||||
|
||||
10. Automatic Licensing of Downstream Recipients.
|
||||
|
||||
Each time you convey a covered work, the recipient automatically
|
||||
receives a license from the original licensors, to run, modify and
|
||||
propagate that work, subject to this License. You are not responsible
|
||||
for enforcing compliance by third parties with this License.
|
||||
|
||||
An "entity transaction" is a transaction transferring control of an
|
||||
organization, or substantially all assets of one, or subdividing an
|
||||
organization, or merging organizations. If propagation of a covered
|
||||
work results from an entity transaction, each party to that
|
||||
transaction who receives a copy of the work also receives whatever
|
||||
licenses to the work the party's predecessor in interest had or could
|
||||
give under the previous paragraph, plus a right to possession of the
|
||||
Corresponding Source of the work from the predecessor in interest, if
|
||||
the predecessor has it or can get it with reasonable efforts.
|
||||
|
||||
You may not impose any further restrictions on the exercise of the
|
||||
rights granted or affirmed under this License. For example, you may
|
||||
not impose a license fee, royalty, or other charge for exercise of
|
||||
rights granted under this License, and you may not initiate litigation
|
||||
(including a cross-claim or counterclaim in a lawsuit) alleging that
|
||||
any patent claim is infringed by making, using, selling, offering for
|
||||
sale, or importing the Program or any portion of it.
|
||||
|
||||
11. Patents.
|
||||
|
||||
A "contributor" is a copyright holder who authorizes use under this
|
||||
License of the Program or a work on which the Program is based. The
|
||||
work thus licensed is called the contributor's "contributor version".
|
||||
|
||||
A contributor's "essential patent claims" are all patent claims
|
||||
owned or controlled by the contributor, whether already acquired or
|
||||
hereafter acquired, that would be infringed by some manner, permitted
|
||||
by this License, of making, using, or selling its contributor version,
|
||||
but do not include claims that would be infringed only as a
|
||||
consequence of further modification of the contributor version. For
|
||||
purposes of this definition, "control" includes the right to grant
|
||||
patent sublicenses in a manner consistent with the requirements of
|
||||
this License.
|
||||
|
||||
Each contributor grants you a non-exclusive, worldwide, royalty-free
|
||||
patent license under the contributor's essential patent claims, to
|
||||
make, use, sell, offer for sale, import and otherwise run, modify and
|
||||
propagate the contents of its contributor version.
|
||||
|
||||
In the following three paragraphs, a "patent license" is any express
|
||||
agreement or commitment, however denominated, not to enforce a patent
|
||||
(such as an express permission to practice a patent or covenant not to
|
||||
sue for patent infringement). To "grant" such a patent license to a
|
||||
party means to make such an agreement or commitment not to enforce a
|
||||
patent against the party.
|
||||
|
||||
If you convey a covered work, knowingly relying on a patent license,
|
||||
and the Corresponding Source of the work is not available for anyone
|
||||
to copy, free of charge and under the terms of this License, through a
|
||||
publicly available network server or other readily accessible means,
|
||||
then you must either (1) cause the Corresponding Source to be so
|
||||
available, or (2) arrange to deprive yourself of the benefit of the
|
||||
patent license for this particular work, or (3) arrange, in a manner
|
||||
consistent with the requirements of this License, to extend the patent
|
||||
license to downstream recipients. "Knowingly relying" means you have
|
||||
actual knowledge that, but for the patent license, your conveying the
|
||||
covered work in a country, or your recipient's use of the covered work
|
||||
in a country, would infringe one or more identifiable patents in that
|
||||
country that you have reason to believe are valid.
|
||||
|
||||
If, pursuant to or in connection with a single transaction or
|
||||
arrangement, you convey, or propagate by procuring conveyance of, a
|
||||
covered work, and grant a patent license to some of the parties
|
||||
receiving the covered work authorizing them to use, propagate, modify
|
||||
or convey a specific copy of the covered work, then the patent license
|
||||
you grant is automatically extended to all recipients of the covered
|
||||
work and works based on it.
|
||||
|
||||
A patent license is "discriminatory" if it does not include within
|
||||
the scope of its coverage, prohibits the exercise of, or is
|
||||
conditioned on the non-exercise of one or more of the rights that are
|
||||
specifically granted under this License. You may not convey a covered
|
||||
work if you are a party to an arrangement with a third party that is
|
||||
in the business of distributing software, under which you make payment
|
||||
to the third party based on the extent of your activity of conveying
|
||||
the work, and under which the third party grants, to any of the
|
||||
parties who would receive the covered work from you, a discriminatory
|
||||
patent license (a) in connection with copies of the covered work
|
||||
conveyed by you (or copies made from those copies), or (b) primarily
|
||||
for and in connection with specific products or compilations that
|
||||
contain the covered work, unless you entered into that arrangement,
|
||||
or that patent license was granted, prior to 28 March 2007.
|
||||
|
||||
Nothing in this License shall be construed as excluding or limiting
|
||||
any implied license or other defenses to infringement that may
|
||||
otherwise be available to you under applicable patent law.
|
||||
|
||||
12. No Surrender of Others' Freedom.
|
||||
|
||||
If conditions are imposed on you (whether by court order, agreement or
|
||||
otherwise) that contradict the conditions of this License, they do not
|
||||
excuse you from the conditions of this License. If you cannot convey a
|
||||
covered work so as to satisfy simultaneously your obligations under this
|
||||
License and any other pertinent obligations, then as a consequence you may
|
||||
not convey it at all. For example, if you agree to terms that obligate you
|
||||
to collect a royalty for further conveying from those to whom you convey
|
||||
the Program, the only way you could satisfy both those terms and this
|
||||
License would be to refrain entirely from conveying the Program.
|
||||
|
||||
13. Remote Network Interaction; Use with the GNU General Public License.
|
||||
|
||||
Notwithstanding any other provision of this License, if you modify the
|
||||
Program, your modified version must prominently offer all users
|
||||
interacting with it remotely through a computer network (if your version
|
||||
supports such interaction) an opportunity to receive the Corresponding
|
||||
Source of your version by providing access to the Corresponding Source
|
||||
from a network server at no charge, through some standard or customary
|
||||
means of facilitating copying of software. This Corresponding Source
|
||||
shall include the Corresponding Source for any work covered by version 3
|
||||
of the GNU General Public License that is incorporated pursuant to the
|
||||
following paragraph.
|
||||
|
||||
Notwithstanding any other provision of this License, you have
|
||||
permission to link or combine any covered work with a work licensed
|
||||
under version 3 of the GNU General Public License into a single
|
||||
combined work, and to convey the resulting work. The terms of this
|
||||
License will continue to apply to the part which is the covered work,
|
||||
but the work with which it is combined will remain governed by version
|
||||
3 of the GNU General Public License.
|
||||
|
||||
14. Revised Versions of this License.
|
||||
|
||||
The Free Software Foundation may publish revised and/or new versions of
|
||||
the GNU Affero General Public License from time to time. Such new versions
|
||||
will be similar in spirit to the present version, but may differ in detail to
|
||||
address new problems or concerns.
|
||||
|
||||
Each version is given a distinguishing version number. If the
|
||||
Program specifies that a certain numbered version of the GNU Affero General
|
||||
Public License "or any later version" applies to it, you have the
|
||||
option of following the terms and conditions either of that numbered
|
||||
version or of any later version published by the Free Software
|
||||
Foundation. If the Program does not specify a version number of the
|
||||
GNU Affero General Public License, you may choose any version ever published
|
||||
by the Free Software Foundation.
|
||||
|
||||
If the Program specifies that a proxy can decide which future
|
||||
versions of the GNU Affero General Public License can be used, that proxy's
|
||||
public statement of acceptance of a version permanently authorizes you
|
||||
to choose that version for the Program.
|
||||
|
||||
Later license versions may give you additional or different
|
||||
permissions. However, no additional obligations are imposed on any
|
||||
author or copyright holder as a result of your choosing to follow a
|
||||
later version.
|
||||
|
||||
15. Disclaimer of Warranty.
|
||||
|
||||
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
|
||||
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
|
||||
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
|
||||
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
|
||||
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
|
||||
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
|
||||
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
|
||||
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
|
||||
|
||||
16. Limitation of Liability.
|
||||
|
||||
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
|
||||
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
|
||||
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
|
||||
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
|
||||
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
|
||||
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
|
||||
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
|
||||
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
|
||||
SUCH DAMAGES.
|
||||
|
||||
17. Interpretation of Sections 15 and 16.
|
||||
|
||||
If the disclaimer of warranty and limitation of liability provided
|
||||
above cannot be given local legal effect according to their terms,
|
||||
reviewing courts shall apply local law that most closely approximates
|
||||
an absolute waiver of all civil liability in connection with the
|
||||
Program, unless a warranty or assumption of liability accompanies a
|
||||
copy of the Program in return for a fee.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
How to Apply These Terms to Your New Programs
|
||||
|
||||
If you develop a new program, and you want it to be of the greatest
|
||||
possible use to the public, the best way to achieve this is to make it
|
||||
free software which everyone can redistribute and change under these terms.
|
||||
|
||||
To do so, attach the following notices to the program. It is safest
|
||||
to attach them to the start of each source file to most effectively
|
||||
state the exclusion of warranty; and each file should have at least
|
||||
the "copyright" line and a pointer to where the full notice is found.
|
||||
|
||||
<one line to give the program's name and a brief idea of what it does.>
|
||||
Copyright (C) <year> <name of author>
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
Also add information on how to contact you by electronic and paper mail.
|
||||
|
||||
If your software can interact with users remotely through a computer
|
||||
network, you should also make sure that it provides a way for users to
|
||||
get its source. For example, if your program is a web application, its
|
||||
interface could display a "Source" link that leads users to an archive
|
||||
of the code. There are many ways you could offer source, and different
|
||||
solutions will be better for different programs; see section 13 for the
|
||||
specific requirements.
|
||||
|
||||
You should also get your employer (if you work as a programmer) or school,
|
||||
if any, to sign a "copyright disclaimer" for the program, if necessary.
|
||||
For more information on this, and how to apply and follow the GNU AGPL, see
|
||||
<https://www.gnu.org/licenses/>.
49 PKGBUILD
@@ -1,49 +0,0 @@
# Maintainer: vikingowl <christian@nachtigall.dev>
pkgname=owlen
pkgver=0.1.11
pkgrel=1
pkgdesc="Terminal User Interface LLM client for Ollama with chat and code assistance features"
arch=('x86_64')
url="https://somegit.dev/Owlibou/owlen"
license=('AGPL-3.0-or-later')
depends=('gcc-libs')
makedepends=('cargo' 'git')
options=(!lto) # avoid LTO-linked ring symbol drop with lld
source=("$pkgname-$pkgver.tar.gz::$url/archive/v$pkgver.tar.gz")
sha256sums=('cabb1cfdfc247b5d008c6c5f94e13548bcefeba874aae9a9d45aa95ae1c085ec')

prepare() {
  cd $pkgname
  cargo fetch --target "$(rustc -vV | sed -n 's/host: //p')"
}

build() {
  cd $pkgname
  export RUSTFLAGS="${RUSTFLAGS:-} -C link-arg=-Wl,--no-as-needed"
  export CARGO_PROFILE_RELEASE_LTO=false
  export CARGO_TARGET_DIR=target
  cargo build --frozen --release --all-features
}

check() {
  cd $pkgname
  export RUSTFLAGS="${RUSTFLAGS:-} -C link-arg=-Wl,--no-as-needed"
  cargo test --frozen --all-features
}

package() {
  cd $pkgname

  # Install binaries
  install -Dm755 target/release/owlen "$pkgdir/usr/bin/owlen"
  install -Dm755 target/release/owlen-code "$pkgdir/usr/bin/owlen-code"

  # Install documentation
  install -Dm644 README.md "$pkgdir/usr/share/doc/$pkgname/README.md"

  # Install built-in themes for reference
  install -Dm644 themes/README.md "$pkgdir/usr/share/$pkgname/themes/README.md"
  for theme in themes/*.toml; do
    install -Dm644 "$theme" "$pkgdir/usr/share/$pkgname/themes/$(basename $theme)"
  done
}
172 README.md
@@ -1,172 +0,0 @@
# OWLEN

> Terminal-native assistant for running local language models with a comfortable TUI.





## What Is OWLEN?

OWLEN is a Rust-powered, terminal-first interface for interacting with local and cloud language models. It provides a responsive chat workflow that routes through a multi-provider manager (local Ollama, Ollama Cloud, and future MCP-backed providers), with a focus on developer productivity, vim-style navigation, and seamless session management, all without leaving your terminal.

## Alpha Status

This project is currently in **alpha** and under active development. Core features are functional, but expect occasional bugs and breaking changes. Feedback, bug reports, and contributions are very welcome!

## Screenshots



The OWLEN interface features a clean, multi-panel layout with vim-inspired navigation. See more screenshots in the [`images/`](images/) directory.

## Features

- **Vim-style Navigation**: Normal, editing, visual, and command modes.
- **Streaming Responses**: Real-time token streaming from Ollama.
- **Advanced Text Editing**: Multi-line input, history, and clipboard support.
- **Session Management**: Save, load, and manage conversations.
- **Code Side Panel**: Switch to code mode (`:mode code`) and open files inline with `:open <path>` for LLM-assisted coding.
- **Theming System**: 10 built-in themes and support for custom themes.
- **Modular Architecture**: Extensible provider system orchestrated by the new `ProviderManager`, ready for additional MCP-backed providers.
- **Dual-Source Model Picker**: Merge local and cloud catalogues with real-time availability badges powered by the background health worker.
- **Non-Blocking UI Loop**: Asynchronous generation tasks and provider health checks run off-thread, keeping the TUI responsive even while streaming long replies.
- **Guided Setup**: `owlen config doctor` upgrades legacy configs and verifies your environment in seconds.

## Security & Privacy

Owlen is designed to keep data local by default while still allowing controlled access to remote tooling.

- **Local-first execution**: All LLM calls flow through the bundled MCP LLM server, which talks to a local Ollama instance. If the server is unreachable, Owlen stays usable in “offline mode” and surfaces clear recovery instructions.
- **Sandboxed tooling**: Code execution runs in Docker according to the MCP Code Server settings, and future releases will extend this to other OS-level sandboxes (`sandbox-exec` on macOS, Windows job objects).
- **Session storage**: Conversations are stored under the platform data directory and can be encrypted at rest. Set `privacy.encrypt_local_data = true` in `config.toml` to enable AES-GCM storage protected by a user-supplied passphrase (see the snippet below).
- **Network access**: No telemetry is sent. The only outbound requests occur when you explicitly enable remote tooling (e.g., web search) or configure a cloud LLM provider. Each tool is opt-in via the `privacy` and `tools` configuration sections.
- **Config migrations**: Every saved `config.toml` carries a schema version and is upgraded automatically; deprecated keys trigger warnings so security-related settings are not silently ignored.
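As a minimal sketch, enabling encrypted session storage is a one-key change in `config.toml`; the `[privacy]` table below matches the example config shown later in this diff, and the passphrase itself is supplied via `OWLEN_MASTER_PASSWORD` or an interactive prompt, as described in SECURITY.md:

```toml
# config.toml: opt in to AES-GCM encryption of locally stored sessions
[privacy]
encrypt_local_data = true
```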
## Getting Started

### Prerequisites
- Rust 1.75+ and Cargo.
- A running Ollama instance.
- A terminal that supports 256 colors.

### Installation

Pick the option that matches your platform and appetite for source builds:

| Platform | Package / Command | Notes |
| --- | --- | --- |
| Arch Linux | `yay -S owlen-git` | Builds from the latest `dev` branch via AUR. |
| Other Linux | `cargo install --path crates/owlen-cli --locked --force` | Requires Rust 1.75+ and a running Ollama daemon. |
| macOS | `cargo install --path crates/owlen-cli --locked --force` | macOS 12+ tested. Install Ollama separately (`brew install ollama`). The binary links against the system OpenSSL – ensure Command Line Tools are installed. |
| Windows (experimental) | `cargo install --path crates/owlen-cli --locked --force` | Enable the GNU toolchain (`rustup target add x86_64-pc-windows-gnu`) and install Ollama for Windows preview builds. Some optional tools (e.g., Docker-based code execution) are currently disabled. |

If you prefer containerised builds, use the provided `Dockerfile` as a base image and copy out `target/release/owlen`.

Run the helper scripts to sanity-check platform coverage:

```bash
# Windows compatibility smoke test (GNU toolchain)
scripts/check-windows.sh

# Reproduce CI packaging locally (choose a target from .woodpecker.yml)
dev/local_build.sh x86_64-unknown-linux-gnu
```

> **Tip (macOS):** On the first launch macOS Gatekeeper may quarantine the binary. Clear the attribute (`xattr -d com.apple.quarantine $(which owlen)`) or build from source locally to avoid notarisation prompts.

### Running OWLEN

Make sure Ollama is running, then launch the application:
```bash
owlen
```
If you built from source without installing, you can run it with:
```bash
./target/release/owlen
```

### Updating

Owlen does not auto-update. Run `owlen upgrade` at any time to print the recommended manual steps (pull the repository and reinstall with `cargo install --path crates/owlen-cli --force`). Arch Linux users can update via the `owlen-git` AUR package.

## Using the TUI

OWLEN uses a modal, vim-inspired interface. Press `F1` (available from any mode) or `?` in Normal mode to view the help screen with all keybindings.

- **Normal Mode**: Navigate with `h/j/k/l`, `w/b`, `gg/G`.
- **Editing Mode**: Enter with `i` or `a`. Send messages with `Enter`.
- **Command Mode**: Enter with `:`. Access commands like `:quit`, `:w`, `:session save`, `:theme`.
- **Quick Exit**: Press `Ctrl+C` twice in Normal mode to quit quickly (first press still cancels active generations).
- **Tutorial Command**: Type `:tutorial` any time for a quick summary of the most important keybindings.
- **MCP Slash Commands**: Owlen auto-registers zero-argument MCP tools as slash commands—type `/mcp__github__list_prs` (for example) to pull remote context directly into the chat log.

Model discovery commands worth remembering:

- `:models --local` or `:models --cloud` jump directly to the corresponding section in the picker.
- `:cloud setup [--force-cloud-base-url]` stores your cloud API key without clobbering an existing local base URL (unless you opt in with the flag).

When a catalogue is unreachable, Owlen now tags the picker with `Local unavailable` / `Cloud unavailable` so you can recover without guessing.

## Documentation

For more detailed information, please refer to the following documents:

- **[CONTRIBUTING.md](CONTRIBUTING.md)**: Guidelines for contributing to the project.
- **[CHANGELOG.md](CHANGELOG.md)**: A log of changes for each version.
- **[docs/architecture.md](docs/architecture.md)**: An overview of the project's architecture.
- **[docs/troubleshooting.md](docs/troubleshooting.md)**: Help with common issues.
- **[docs/repo-map.md](docs/repo-map.md)**: Snapshot of the workspace layout and key crates.
- **[docs/provider-implementation.md](docs/provider-implementation.md)**: Trait-level details for implementing providers.
- **[docs/adding-providers.md](docs/adding-providers.md)**: Step-by-step checklist for wiring a provider into the multi-provider architecture and test suite.
- **Experimental providers staging area**: [crates/providers/experimental/README.md](crates/providers/experimental/README.md) records the placeholder crates (OpenAI, Anthropic, Gemini) and their current status.
- **[docs/platform-support.md](docs/platform-support.md)**: Current OS support matrix and cross-check instructions.

## Configuration

OWLEN stores its configuration in the standard platform-specific config directory:

| Platform | Location |
|----------|----------|
| Linux | `~/.config/owlen/config.toml` |
| macOS | `~/Library/Application Support/owlen/config.toml` |
| Windows | `%APPDATA%\owlen\config.toml` |

Use `owlen config path` to print the exact location on your machine and `owlen config doctor` to migrate a legacy config automatically.
You can also add custom themes alongside the config directory (e.g., `~/.config/owlen/themes/`).

See the [themes/README.md](themes/README.md) for more details on theming.
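For orientation, the top of `config.toml` is where the default provider and model are chosen. The keys below mirror the example `config.toml` shown later in this diff; adjust the model name to whatever you have pulled in Ollama:

```toml
[general]
default_provider = "ollama_local"
default_model = "llama3.2:latest"

[providers.ollama_local]
enabled = true
provider_type = "ollama"
base_url = "http://localhost:11434"
```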
## Testing

Owlen uses standard Rust tooling for verification. Run the full test suite with:

```bash
cargo test
```

Unit tests cover the command palette state machine, agent response parsing, and key MCP abstractions. Formatting and lint checks can be run with `cargo fmt --all` and `cargo clippy` respectively.

## Roadmap

Upcoming milestones focus on feature parity with modern code assistants while keeping Owlen local-first:

1. **Phase 11 – MCP client enhancements**: `owlen mcp add/list/remove`, resource references (`@github:issue://123`), and MCP prompt slash commands.
2. **Phase 12 – Approval & sandboxing**: Three-tier approval modes plus platform-specific sandboxes (Docker, `sandbox-exec`, Windows job objects).
3. **Phase 13 – Project documentation system**: Automatic `OWLEN.md` generation, contextual updates, and nested project support.
4. **Phase 15 – Provider expansion**: OpenAI, Anthropic, and other cloud providers layered onto the existing Ollama-first architecture.

See `AGENTS.md` for the long-form roadmap and design notes.

## Contributing

Contributions are highly welcome! Please see our **[Contributing Guide](CONTRIBUTING.md)** for details on how to get started, including our code style, commit conventions, and pull request process.

## License

This project is licensed under the GNU Affero General Public License v3.0. See the [LICENSE](LICENSE) file for details.
For commercial or proprietary integrations that cannot adopt AGPL, please reach out to the maintainers to discuss alternative licensing arrangements.
40 SECURITY.md
@@ -1,40 +0,0 @@
# Security Policy
|
||||
|
||||
## Supported Versions
|
||||
|
||||
We are currently in a pre-release phase, so only the latest version is actively supported. As we move towards a 1.0 release, this policy will be updated with specific version support.
|
||||
|
||||
| Version | Supported |
|
||||
| ------- | ------------------ |
|
||||
| < 1.0 | :white_check_mark: |
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
The Owlen team and community take all security vulnerabilities seriously. Thank you for improving the security of our project. We appreciate your efforts and responsible disclosure and will make every effort to acknowledge your contributions.
|
||||
|
||||
To report a security vulnerability, please email the project lead at [security@owlibou.com](mailto:security@owlibou.com) with a detailed description of the issue, the steps to reproduce it, and any affected versions.
|
||||
|
||||
You will receive a response from us within 48 hours. If the issue is confirmed, we will release a patch as soon as possible, depending on the complexity of the issue.
|
||||
|
||||
Please do not report security vulnerabilities through public GitHub issues.
|
||||
|
||||
## Design Overview
|
||||
|
||||
Owlen ships with a local-first architecture:
|
||||
|
||||
- **Process isolation** – The TUI speaks to language models through a separate MCP LLM server. Tool execution (code, web, filesystem) occurs in dedicated MCP processes so a crash or hang cannot take down the UI.
|
||||
- **Sandboxing** – The MCP Code Server executes snippets in Docker containers. Upcoming releases will extend this to platform sandboxes (`sandbox-exec` on macOS, Windows job objects) as described in our roadmap.
|
||||
- **Network posture** – No telemetry is emitted. The application only reaches the network when a user explicitly enables remote tools (web search, remote MCP servers) or configures cloud providers. All tools require allow-listing in `config.toml`.
|
||||
|
||||
## Data Handling
|
||||
|
||||
- **Sessions** – Conversations are stored in the user’s data directory (`~/.local/share/owlen` on Linux, equivalent paths on macOS/Windows). Enable `privacy.encrypt_local_data = true` to wrap the session store in AES-GCM encryption protected by a passphrase (`OWLEN_MASTER_PASSWORD` or an interactive prompt).
|
||||
- **Credentials** – API tokens are resolved from the config file or environment variables at runtime and are never written to logs.
|
||||
- **Remote calls** – When remote search or cloud LLM tooling is on, only the minimum payload (prompt, tool arguments) is sent. All outbound requests go through the MCP servers so they can be audited or disabled centrally.
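
As a minimal sketch (assuming the `owlen` binary is on `PATH`; the variable name comes from the paragraph above), the passphrase can be supplied non-interactively like this:

```bash
# Assumes privacy.encrypt_local_data = true is already set in config.toml.
# Exporting the passphrase skips the interactive prompt; unset it afterwards.
export OWLEN_MASTER_PASSWORD='use-a-long-unique-passphrase'
owlen
```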

## Supply-Chain Safeguards

- The repository includes a git `pre-commit` configuration that runs `cargo fmt`, `cargo check`, and `cargo clippy -- -D warnings` on every commit; a sketch of an equivalent hook follows this list.
- Pull requests generated with the assistance of AI tooling must receive manual maintainer review before merging. Contributors are asked to declare AI involvement in their PR description so maintainers can double-check the changes.
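
A hand-rolled hook along these lines reproduces those checks (a sketch only; the hook file path and exact flags are assumptions, not a description of the shipped configuration):

```bash
#!/bin/sh
# .git/hooks/pre-commit (hypothetical path): mirrors the checks listed above.
set -e
cargo fmt --all -- --check
cargo check --workspace
cargo clippy -- -D warnings
```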

Additional recommendations for operators (e.g., running Owlen on shared systems) are maintained in `docs/security.md` (planned) and the issue tracker.
29
config.toml
29
config.toml
@@ -1,29 +0,0 @@
[general]
default_provider = "ollama_local"
default_model = "llama3.2:latest"

[privacy]
encrypt_local_data = true

[providers.ollama_local]
enabled = true
provider_type = "ollama"
base_url = "http://localhost:11434"

[providers.ollama_cloud]
enabled = false
provider_type = "ollama_cloud"
base_url = "https://ollama.com"
api_key_env = "OLLAMA_CLOUD_API_KEY"

[providers.openai]
enabled = false
provider_type = "openai"
base_url = "https://api.openai.com/v1"
api_key_env = "OPENAI_API_KEY"

[providers.anthropic]
enabled = false
provider_type = "anthropic"
base_url = "https://api.anthropic.com/v1"
api_key_env = "ANTHROPIC_API_KEY"
22
crates/app/cli/.gitignore
vendored
Normal file
22
crates/app/cli/.gitignore
vendored
Normal file
@@ -0,0 +1,22 @@
/target
### Rust template
# Generated by Cargo
# will have compiled files and executables
debug/
target/

# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Cargo.lock

# These are backup files generated by rustfmt
**/*.rs.bk

# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb

### rust-analyzer template
# Can be generated by other build systems other than cargo (ex: bazelbuild/rust_rules)
rust-project.json
33
crates/app/cli/Cargo.toml
Normal file
33
crates/app/cli/Cargo.toml
Normal file
@@ -0,0 +1,33 @@
[package]
name = "owlen"
version = "0.1.0"
edition.workspace = true
license.workspace = true
rust-version.workspace = true

[dependencies]
clap = { version = "4.5", features = ["derive"] }
tokio = { version = "1.39", features = ["macros", "rt-multi-thread"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
color-eyre = "0.6"
agent-core = { path = "../../core/agent" }
llm-core = { path = "../../llm/core" }
llm-ollama = { path = "../../llm/ollama" }
tools-fs = { path = "../../tools/fs" }
tools-bash = { path = "../../tools/bash" }
tools-slash = { path = "../../tools/slash" }
config-agent = { package = "config-agent", path = "../../platform/config" }
permissions = { path = "../../platform/permissions" }
hooks = { path = "../../platform/hooks" }
plugins = { path = "../../platform/plugins" }
ui = { path = "../ui" }
atty = "0.2"
futures-util = "0.3.31"

[dev-dependencies]
assert_cmd = "2.0"
predicates = "3.1"
httpmock = "0.7"
tokio = { version = "1.39", features = ["macros", "rt-multi-thread"] }
tempfile = "3.23.0"
382
crates/app/cli/src/commands.rs
Normal file
382
crates/app/cli/src/commands.rs
Normal file
@@ -0,0 +1,382 @@
|
||||
//! Built-in commands for CLI and TUI
|
||||
//!
|
||||
//! Provides handlers for /help, /mcp, /hooks, /clear, and other built-in commands.
|
||||
|
||||
use ui::{CommandInfo, CommandOutput, OutputFormat, TreeNode, ListItem};
|
||||
use permissions::PermissionManager;
|
||||
use hooks::HookManager;
|
||||
use plugins::PluginManager;
|
||||
use agent_core::SessionStats;
|
||||
|
||||
/// Result of executing a built-in command
|
||||
pub enum CommandResult {
|
||||
/// Command produced output to display
|
||||
Output(CommandOutput),
|
||||
/// Command was handled but produced no output (e.g., /clear)
|
||||
Handled,
|
||||
/// Command was not recognized
|
||||
NotFound,
|
||||
/// Command needs to exit the session
|
||||
Exit,
|
||||
}
|
||||
|
||||
/// Built-in command handler
|
||||
pub struct BuiltinCommands<'a> {
|
||||
plugin_manager: Option<&'a PluginManager>,
|
||||
hook_manager: Option<&'a HookManager>,
|
||||
permission_manager: Option<&'a PermissionManager>,
|
||||
stats: Option<&'a SessionStats>,
|
||||
}
|
||||
|
||||
impl<'a> BuiltinCommands<'a> {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
plugin_manager: None,
|
||||
hook_manager: None,
|
||||
permission_manager: None,
|
||||
stats: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_plugins(mut self, pm: &'a PluginManager) -> Self {
|
||||
self.plugin_manager = Some(pm);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_hooks(mut self, hm: &'a HookManager) -> Self {
|
||||
self.hook_manager = Some(hm);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_permissions(mut self, perms: &'a PermissionManager) -> Self {
|
||||
self.permission_manager = Some(perms);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_stats(mut self, stats: &'a SessionStats) -> Self {
|
||||
self.stats = Some(stats);
|
||||
self
|
||||
}
|
||||
|
||||
/// Execute a built-in command
|
||||
pub fn execute(&self, command: &str) -> CommandResult {
|
||||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||||
let cmd = parts.first().map(|s| s.trim_start_matches('/'));
|
||||
|
||||
match cmd {
|
||||
Some("help") | Some("?") => CommandResult::Output(self.help()),
|
||||
Some("mcp") => CommandResult::Output(self.mcp()),
|
||||
Some("hooks") => CommandResult::Output(self.hooks()),
|
||||
Some("plugins") => CommandResult::Output(self.plugins()),
|
||||
Some("status") => CommandResult::Output(self.status()),
|
||||
Some("permissions") | Some("perms") => CommandResult::Output(self.permissions()),
|
||||
Some("clear") => CommandResult::Handled,
|
||||
Some("exit") | Some("quit") | Some("q") => CommandResult::Exit,
|
||||
_ => CommandResult::NotFound,
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate help output
|
||||
fn help(&self) -> CommandOutput {
|
||||
let mut commands = vec![
|
||||
// Built-in commands
|
||||
CommandInfo::new("help", "Show available commands", "builtin"),
|
||||
CommandInfo::new("clear", "Clear the screen", "builtin"),
|
||||
CommandInfo::new("status", "Show session status", "builtin"),
|
||||
CommandInfo::new("permissions", "Show permission settings", "builtin"),
|
||||
CommandInfo::new("mcp", "List MCP servers and tools", "builtin"),
|
||||
CommandInfo::new("hooks", "Show loaded hooks", "builtin"),
|
||||
CommandInfo::new("plugins", "Show loaded plugins", "builtin"),
|
||||
CommandInfo::new("checkpoint", "Save session state", "builtin"),
|
||||
CommandInfo::new("checkpoints", "List saved checkpoints", "builtin"),
|
||||
CommandInfo::new("rewind", "Restore from checkpoint", "builtin"),
|
||||
CommandInfo::new("compact", "Compact conversation context", "builtin"),
|
||||
CommandInfo::new("exit", "Exit the session", "builtin"),
|
||||
];
|
||||
|
||||
// Add plugin commands
|
||||
if let Some(pm) = self.plugin_manager {
|
||||
for plugin in pm.plugins() {
|
||||
for cmd_name in plugin.all_command_names() {
|
||||
commands.push(CommandInfo::new(
|
||||
&cmd_name,
|
||||
&format!("Plugin command from {}", plugin.manifest.name),
|
||||
&format!("plugin:{}", plugin.manifest.name),
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CommandOutput::help_table(&commands)
|
||||
}
|
||||
|
||||
/// Generate MCP servers output
|
||||
fn mcp(&self) -> CommandOutput {
|
||||
let mut servers: Vec<(String, Vec<String>)> = vec![];
|
||||
|
||||
// Get MCP servers from plugins
|
||||
if let Some(pm) = self.plugin_manager {
|
||||
for plugin in pm.plugins() {
|
||||
// Check for .mcp.json in plugin directory
|
||||
let mcp_path = plugin.base_path.join(".mcp.json");
|
||||
if mcp_path.exists() {
|
||||
if let Ok(content) = std::fs::read_to_string(&mcp_path) {
|
||||
if let Ok(config) = serde_json::from_str::<serde_json::Value>(&content) {
|
||||
if let Some(mcpservers) = config.get("mcpServers").and_then(|v| v.as_object()) {
|
||||
for (name, _) in mcpservers {
|
||||
servers.push((
|
||||
format!("{} ({})", name, plugin.manifest.name),
|
||||
vec!["(connect to discover tools)".to_string()],
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if servers.is_empty() {
|
||||
CommandOutput::new(OutputFormat::Text {
|
||||
content: "No MCP servers configured.\n\nAdd MCP servers in plugin .mcp.json files.".to_string(),
|
||||
})
|
||||
} else {
|
||||
CommandOutput::mcp_tree(&servers)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate hooks output
|
||||
fn hooks(&self) -> CommandOutput {
|
||||
let mut hooks_list: Vec<(String, String, bool)> = vec![];
|
||||
|
||||
// Check for file-based hooks in .owlen/hooks/
|
||||
let hook_events = ["PreToolUse", "PostToolUse", "SessionStart", "SessionEnd",
|
||||
"UserPromptSubmit", "PreCompact", "Stop", "SubagentStop"];
|
||||
|
||||
for event in hook_events {
|
||||
let path = format!(".owlen/hooks/{}", event);
|
||||
let exists = std::path::Path::new(&path).exists();
|
||||
if exists {
|
||||
hooks_list.push((event.to_string(), path, true));
|
||||
}
|
||||
}
|
||||
|
||||
// Get hooks from plugins
|
||||
if let Some(pm) = self.plugin_manager {
|
||||
for plugin in pm.plugins() {
|
||||
if let Some(hooks_config) = plugin.load_hooks_config().ok().flatten() {
|
||||
// hooks_config.hooks is HashMap<String, Vec<HookMatcher>>
|
||||
for (event_name, matchers) in &hooks_config.hooks {
|
||||
for matcher in matchers {
|
||||
for hook_def in &matcher.hooks {
|
||||
let cmd = hook_def.command.as_deref()
|
||||
.or(hook_def.prompt.as_deref())
|
||||
.unwrap_or("(no command)");
|
||||
hooks_list.push((
|
||||
event_name.clone(),
|
||||
format!("{}: {}", plugin.manifest.name, cmd),
|
||||
true,
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if hooks_list.is_empty() {
|
||||
CommandOutput::new(OutputFormat::Text {
|
||||
content: "No hooks configured.\n\nAdd hooks in .owlen/hooks/ or plugin hooks.json files.".to_string(),
|
||||
})
|
||||
} else {
|
||||
CommandOutput::hooks_list(&hooks_list)
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate plugins output
|
||||
fn plugins(&self) -> CommandOutput {
|
||||
if let Some(pm) = self.plugin_manager {
|
||||
let plugins = pm.plugins();
|
||||
if plugins.is_empty() {
|
||||
return CommandOutput::new(OutputFormat::Text {
|
||||
content: "No plugins loaded.\n\nPlace plugins in:\n - ~/.config/owlen/plugins (user)\n - .owlen/plugins (project)".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
// Build tree of plugins and their components
|
||||
let children: Vec<TreeNode> = plugins.iter().map(|p| {
|
||||
let mut plugin_children = vec![];
|
||||
|
||||
let commands = p.all_command_names();
|
||||
if !commands.is_empty() {
|
||||
plugin_children.push(TreeNode::new("Commands").with_children(
|
||||
commands.iter().map(|c| TreeNode::new(format!("/{}", c))).collect()
|
||||
));
|
||||
}
|
||||
|
||||
let agents = p.all_agent_names();
|
||||
if !agents.is_empty() {
|
||||
plugin_children.push(TreeNode::new("Agents").with_children(
|
||||
agents.iter().map(|a| TreeNode::new(a)).collect()
|
||||
));
|
||||
}
|
||||
|
||||
let skills = p.all_skill_names();
|
||||
if !skills.is_empty() {
|
||||
plugin_children.push(TreeNode::new("Skills").with_children(
|
||||
skills.iter().map(|s| TreeNode::new(s)).collect()
|
||||
));
|
||||
}
|
||||
|
||||
TreeNode::new(format!("{} v{}", p.manifest.name, p.manifest.version))
|
||||
.with_children(plugin_children)
|
||||
}).collect();
|
||||
|
||||
CommandOutput::new(OutputFormat::Tree {
|
||||
root: TreeNode::new("Loaded Plugins").with_children(children),
|
||||
})
|
||||
} else {
|
||||
CommandOutput::new(OutputFormat::Text {
|
||||
content: "Plugin manager not available.".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate status output
|
||||
fn status(&self) -> CommandOutput {
|
||||
let mut items = vec![];
|
||||
|
||||
if let Some(stats) = self.stats {
|
||||
items.push(ListItem {
|
||||
text: format!("Messages: {}", stats.total_messages),
|
||||
marker: Some("📊".to_string()),
|
||||
style: None,
|
||||
});
|
||||
items.push(ListItem {
|
||||
text: format!("Tool Calls: {}", stats.total_tool_calls),
|
||||
marker: Some("🔧".to_string()),
|
||||
style: None,
|
||||
});
|
||||
items.push(ListItem {
|
||||
text: format!("Est. Tokens: ~{}", stats.estimated_tokens),
|
||||
marker: Some("📝".to_string()),
|
||||
style: None,
|
||||
});
|
||||
let uptime = stats.start_time.elapsed().unwrap_or_default();
|
||||
items.push(ListItem {
|
||||
text: format!("Uptime: {}", SessionStats::format_duration(uptime)),
|
||||
marker: Some("⏱️".to_string()),
|
||||
style: None,
|
||||
});
|
||||
}
|
||||
|
||||
if let Some(perms) = self.permission_manager {
|
||||
items.push(ListItem {
|
||||
text: format!("Mode: {:?}", perms.mode()),
|
||||
marker: Some("🔒".to_string()),
|
||||
style: None,
|
||||
});
|
||||
}
|
||||
|
||||
if items.is_empty() {
|
||||
CommandOutput::new(OutputFormat::Text {
|
||||
content: "Session status not available.".to_string(),
|
||||
})
|
||||
} else {
|
||||
CommandOutput::new(OutputFormat::List { items })
|
||||
}
|
||||
}
|
||||
|
||||
/// Generate permissions output
|
||||
fn permissions(&self) -> CommandOutput {
|
||||
if let Some(perms) = self.permission_manager {
|
||||
let mode = perms.mode();
|
||||
let mode_str = format!("{:?}", mode);
|
||||
|
||||
let mut items = vec![
|
||||
ListItem {
|
||||
text: format!("Current Mode: {}", mode_str),
|
||||
marker: Some("🔒".to_string()),
|
||||
style: None,
|
||||
},
|
||||
];
|
||||
|
||||
// Add tool permissions summary
|
||||
let (read_status, write_status, bash_status) = match mode {
|
||||
permissions::Mode::Plan => ("✅ Allowed", "❓ Ask", "❓ Ask"),
|
||||
permissions::Mode::AcceptEdits => ("✅ Allowed", "✅ Allowed", "❓ Ask"),
|
||||
permissions::Mode::Code => ("✅ Allowed", "✅ Allowed", "✅ Allowed"),
|
||||
};
|
||||
|
||||
items.push(ListItem {
|
||||
text: format!("Read/Grep/Glob: {}", read_status),
|
||||
marker: None,
|
||||
style: None,
|
||||
});
|
||||
items.push(ListItem {
|
||||
text: format!("Write/Edit: {}", write_status),
|
||||
marker: None,
|
||||
style: None,
|
||||
});
|
||||
items.push(ListItem {
|
||||
text: format!("Bash: {}", bash_status),
|
||||
marker: None,
|
||||
style: None,
|
||||
});
|
||||
|
||||
CommandOutput::new(OutputFormat::List { items })
|
||||
} else {
|
||||
CommandOutput::new(OutputFormat::Text {
|
||||
content: "Permission manager not available.".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for BuiltinCommands<'_> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_help_command() {
|
||||
let handler = BuiltinCommands::new();
|
||||
match handler.execute("/help") {
|
||||
CommandResult::Output(output) => {
|
||||
match output.format {
|
||||
OutputFormat::Table { headers, rows } => {
|
||||
assert!(!headers.is_empty());
|
||||
assert!(!rows.is_empty());
|
||||
}
|
||||
_ => panic!("Expected Table format"),
|
||||
}
|
||||
}
|
||||
_ => panic!("Expected Output result"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_exit_command() {
|
||||
let handler = BuiltinCommands::new();
|
||||
assert!(matches!(handler.execute("/exit"), CommandResult::Exit));
|
||||
assert!(matches!(handler.execute("/quit"), CommandResult::Exit));
|
||||
assert!(matches!(handler.execute("/q"), CommandResult::Exit));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clear_command() {
|
||||
let handler = BuiltinCommands::new();
|
||||
assert!(matches!(handler.execute("/clear"), CommandResult::Handled));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_unknown_command() {
|
||||
let handler = BuiltinCommands::new();
|
||||
assert!(matches!(handler.execute("/unknown"), CommandResult::NotFound));
|
||||
}
|
||||
}
|
||||
873
crates/app/cli/src/main.rs
Normal file
873
crates/app/cli/src/main.rs
Normal file
@@ -0,0 +1,873 @@
|
||||
mod commands;
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
use color_eyre::eyre::{Result, eyre};
|
||||
use config_agent::load_settings;
|
||||
use hooks::{HookEvent, HookManager, HookResult};
|
||||
use llm_core::ChatOptions;
|
||||
use llm_ollama::OllamaClient;
|
||||
use permissions::{PermissionDecision, Tool};
|
||||
use plugins::PluginManager;
|
||||
use serde::Serialize;
|
||||
use std::io::Write;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
pub use commands::{BuiltinCommands, CommandResult};
|
||||
|
||||
#[derive(Debug, Clone, Copy, ValueEnum)]
|
||||
enum OutputFormat {
|
||||
Text,
|
||||
Json,
|
||||
StreamJson,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct SessionOutput {
|
||||
session_id: String,
|
||||
messages: Vec<serde_json::Value>,
|
||||
stats: Stats,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
result: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct Stats {
|
||||
total_tokens: u64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
prompt_tokens: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
completion_tokens: Option<u64>,
|
||||
duration_ms: u64,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct StreamEvent {
|
||||
#[serde(rename = "type")]
|
||||
event_type: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
session_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
stats: Option<Stats>,
|
||||
}
|
||||
|
||||
/// Application context shared across the session
|
||||
pub struct AppContext {
|
||||
pub plugin_manager: PluginManager,
|
||||
pub config: config_agent::Settings,
|
||||
}
|
||||
|
||||
impl AppContext {
|
||||
pub fn new() -> Result<Self> {
|
||||
let config = load_settings(None).unwrap_or_default();
|
||||
|
||||
let mut plugin_manager = PluginManager::new();
|
||||
// Non-fatal: just log warnings, don't fail startup
|
||||
if let Err(e) = plugin_manager.load_all() {
|
||||
eprintln!("Warning: Failed to load some plugins: {}", e);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
plugin_manager,
|
||||
config,
|
||||
})
|
||||
}
|
||||
|
||||
/// Print loaded plugins and available commands
|
||||
pub fn print_plugin_info(&self) {
|
||||
let plugins = self.plugin_manager.plugins();
|
||||
if !plugins.is_empty() {
|
||||
println!("\nLoaded {} plugin(s):", plugins.len());
|
||||
for plugin in plugins {
|
||||
println!(" - {} v{}", plugin.manifest.name, plugin.manifest.version);
|
||||
if let Some(desc) = &plugin.manifest.description {
|
||||
println!(" {}", desc);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let commands = self.plugin_manager.all_commands();
|
||||
if !commands.is_empty() {
|
||||
println!("\nAvailable plugin commands:");
|
||||
for (name, _path) in &commands {
|
||||
println!(" /{}", name);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn generate_session_id() -> String {
|
||||
let timestamp = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_millis();
|
||||
format!("session-{}", timestamp)
|
||||
}
|
||||
|
||||
fn output_tool_result(
|
||||
format: OutputFormat,
|
||||
tool: &str,
|
||||
result: serde_json::Value,
|
||||
session_id: &str,
|
||||
) -> Result<()> {
|
||||
match format {
|
||||
OutputFormat::Text => {
|
||||
// For text, just print the result as-is
|
||||
if let Some(s) = result.as_str() {
|
||||
println!("{}", s);
|
||||
} else {
|
||||
println!("{}", serde_json::to_string_pretty(&result)?);
|
||||
}
|
||||
}
|
||||
OutputFormat::Json => {
|
||||
let output = SessionOutput {
|
||||
session_id: session_id.to_string(),
|
||||
messages: vec![],
|
||||
stats: Stats {
|
||||
total_tokens: 0,
|
||||
prompt_tokens: None,
|
||||
completion_tokens: None,
|
||||
duration_ms: 0,
|
||||
},
|
||||
result: Some(result),
|
||||
tool: Some(tool.to_string()),
|
||||
};
|
||||
println!("{}", serde_json::to_string(&output)?);
|
||||
}
|
||||
OutputFormat::StreamJson => {
|
||||
// For stream-json, emit session_start, result, and session_end
|
||||
let session_start = StreamEvent {
|
||||
event_type: "session_start".to_string(),
|
||||
session_id: Some(session_id.to_string()),
|
||||
content: None,
|
||||
stats: None,
|
||||
};
|
||||
println!("{}", serde_json::to_string(&session_start)?);
|
||||
|
||||
let result_event = StreamEvent {
|
||||
event_type: "tool_result".to_string(),
|
||||
session_id: None,
|
||||
content: Some(serde_json::to_string(&result)?),
|
||||
stats: None,
|
||||
};
|
||||
println!("{}", serde_json::to_string(&result_event)?);
|
||||
|
||||
let session_end = StreamEvent {
|
||||
event_type: "session_end".to_string(),
|
||||
session_id: None,
|
||||
content: None,
|
||||
stats: Some(Stats {
|
||||
total_tokens: 0,
|
||||
prompt_tokens: None,
|
||||
completion_tokens: None,
|
||||
duration_ms: 0,
|
||||
}),
|
||||
};
|
||||
println!("{}", serde_json::to_string(&session_end)?);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(clap::Subcommand, Debug)]
|
||||
enum Cmd {
|
||||
Read { path: String },
|
||||
Glob { pattern: String },
|
||||
Grep { root: String, pattern: String },
|
||||
Write { path: String, content: String },
|
||||
Edit { path: String, old_string: String, new_string: String },
|
||||
Bash { command: String, #[arg(long)] timeout: Option<u64> },
|
||||
Slash { command_name: String, args: Vec<String> },
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "code", version)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
ollama_url: Option<String>,
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
#[arg(long)]
|
||||
api_key: Option<String>,
|
||||
#[arg(long)]
|
||||
print: bool,
|
||||
/// Override the permission mode (plan, acceptEdits, code)
|
||||
#[arg(long)]
|
||||
mode: Option<String>,
|
||||
/// Output format (text, json, stream-json)
|
||||
#[arg(long, value_enum, default_value = "text")]
|
||||
output_format: OutputFormat,
|
||||
/// Disable TUI and use legacy text-based REPL
|
||||
#[arg(long)]
|
||||
no_tui: bool,
|
||||
#[arg()]
|
||||
prompt: Vec<String>,
|
||||
#[command(subcommand)]
|
||||
cmd: Option<Cmd>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
color_eyre::install()?;
|
||||
let args = Args::parse();
|
||||
|
||||
// Initialize application context with plugins
|
||||
let app_context = AppContext::new()?;
|
||||
let mut settings = app_context.config.clone();
|
||||
|
||||
// Override mode if specified via CLI
|
||||
if let Some(mode) = args.mode {
|
||||
settings.mode = mode;
|
||||
}
|
||||
|
||||
// Create permission manager from settings
|
||||
let perms = settings.create_permission_manager();
|
||||
|
||||
// Create hook manager
|
||||
let mut hook_mgr = HookManager::new(".");
|
||||
|
||||
// Register plugin hooks
|
||||
for plugin in app_context.plugin_manager.plugins() {
|
||||
if let Ok(Some(hooks_config)) = plugin.load_hooks_config() {
|
||||
for (event, command, pattern, timeout) in plugin.register_hooks_with_manager(&hooks_config) {
|
||||
hook_mgr.register_hook(event, command, pattern, timeout);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate session ID
|
||||
let session_id = generate_session_id();
|
||||
let output_format = args.output_format;
|
||||
|
||||
if let Some(cmd) = args.cmd {
|
||||
match cmd {
|
||||
Cmd::Read { path } => {
|
||||
// Check permission
|
||||
match perms.check(Tool::Read, None) {
|
||||
PermissionDecision::Allow => {
|
||||
// Check PreToolUse hook
|
||||
let event = HookEvent::PreToolUse {
|
||||
tool: "Read".to_string(),
|
||||
args: serde_json::json!({"path": &path}),
|
||||
};
|
||||
match hook_mgr.execute(&event, Some(5000)).await? {
|
||||
HookResult::Deny => {
|
||||
return Err(eyre!("Hook denied Read operation"));
|
||||
}
|
||||
HookResult::Allow => {}
|
||||
}
|
||||
|
||||
let s = tools_fs::read_file(&path)?;
|
||||
output_tool_result(output_format, "Read", serde_json::json!(s), &session_id)?;
|
||||
return Ok(());
|
||||
}
|
||||
PermissionDecision::Ask => {
|
||||
return Err(eyre!(
|
||||
"Permission denied: Read operation requires approval. Use --mode code to allow."
|
||||
));
|
||||
}
|
||||
PermissionDecision::Deny => {
|
||||
return Err(eyre!("Permission denied: Read operation is blocked."));
|
||||
}
|
||||
}
|
||||
}
|
||||
Cmd::Glob { pattern } => {
|
||||
// Check permission
|
||||
match perms.check(Tool::Glob, None) {
|
||||
PermissionDecision::Allow => {
|
||||
// Check PreToolUse hook
|
||||
let event = HookEvent::PreToolUse {
|
||||
tool: "Glob".to_string(),
|
||||
args: serde_json::json!({"pattern": &pattern}),
|
||||
};
|
||||
match hook_mgr.execute(&event, Some(5000)).await? {
|
||||
HookResult::Deny => {
|
||||
return Err(eyre!("Hook denied Glob operation"));
|
||||
}
|
||||
HookResult::Allow => {}
|
||||
}
|
||||
|
||||
for p in tools_fs::glob_list(&pattern)? {
|
||||
println!("{}", p);
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
PermissionDecision::Ask => {
|
||||
return Err(eyre!(
|
||||
"Permission denied: Glob operation requires approval. Use --mode code to allow."
|
||||
));
|
||||
}
|
||||
PermissionDecision::Deny => {
|
||||
return Err(eyre!("Permission denied: Glob operation is blocked."));
|
||||
}
|
||||
}
|
||||
}
|
||||
Cmd::Grep { root, pattern } => {
|
||||
// Check permission
|
||||
match perms.check(Tool::Grep, None) {
|
||||
PermissionDecision::Allow => {
|
||||
// Check PreToolUse hook
|
||||
let event = HookEvent::PreToolUse {
|
||||
tool: "Grep".to_string(),
|
||||
args: serde_json::json!({"root": &root, "pattern": &pattern}),
|
||||
};
|
||||
match hook_mgr.execute(&event, Some(5000)).await? {
|
||||
HookResult::Deny => {
|
||||
return Err(eyre!("Hook denied Grep operation"));
|
||||
}
|
||||
HookResult::Allow => {}
|
||||
}
|
||||
|
||||
for (path, line_number, text) in tools_fs::grep(&root, &pattern)? {
|
||||
println!("{path}:{line_number}:{text}")
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
PermissionDecision::Ask => {
|
||||
return Err(eyre!(
|
||||
"Permission denied: Grep operation requires approval. Use --mode code to allow."
|
||||
));
|
||||
}
|
||||
PermissionDecision::Deny => {
|
||||
return Err(eyre!("Permission denied: Grep operation is blocked."));
|
||||
}
|
||||
}
|
||||
}
|
||||
Cmd::Write { path, content } => {
|
||||
// Check permission
|
||||
match perms.check(Tool::Write, None) {
|
||||
PermissionDecision::Allow => {
|
||||
// Check PreToolUse hook
|
||||
let event = HookEvent::PreToolUse {
|
||||
tool: "Write".to_string(),
|
||||
args: serde_json::json!({"path": &path, "content": &content}),
|
||||
};
|
||||
match hook_mgr.execute(&event, Some(5000)).await? {
|
||||
HookResult::Deny => {
|
||||
return Err(eyre!("Hook denied Write operation"));
|
||||
}
|
||||
HookResult::Allow => {}
|
||||
}
|
||||
|
||||
tools_fs::write_file(&path, &content)?;
|
||||
println!("File written: {}", path);
|
||||
return Ok(());
|
||||
}
|
||||
PermissionDecision::Ask => {
|
||||
return Err(eyre!(
|
||||
"Permission denied: Write operation requires approval. Use --mode acceptEdits or --mode code to allow."
|
||||
));
|
||||
}
|
||||
PermissionDecision::Deny => {
|
||||
return Err(eyre!("Permission denied: Write operation is blocked."));
|
||||
}
|
||||
}
|
||||
}
|
||||
Cmd::Edit { path, old_string, new_string } => {
|
||||
// Check permission
|
||||
match perms.check(Tool::Edit, None) {
|
||||
PermissionDecision::Allow => {
|
||||
// Check PreToolUse hook
|
||||
let event = HookEvent::PreToolUse {
|
||||
tool: "Edit".to_string(),
|
||||
args: serde_json::json!({"path": &path, "old_string": &old_string, "new_string": &new_string}),
|
||||
};
|
||||
match hook_mgr.execute(&event, Some(5000)).await? {
|
||||
HookResult::Deny => {
|
||||
return Err(eyre!("Hook denied Edit operation"));
|
||||
}
|
||||
HookResult::Allow => {}
|
||||
}
|
||||
|
||||
tools_fs::edit_file(&path, &old_string, &new_string)?;
|
||||
println!("File edited: {}", path);
|
||||
return Ok(());
|
||||
}
|
||||
PermissionDecision::Ask => {
|
||||
return Err(eyre!(
|
||||
"Permission denied: Edit operation requires approval. Use --mode acceptEdits or --mode code to allow."
|
||||
));
|
||||
}
|
||||
PermissionDecision::Deny => {
|
||||
return Err(eyre!("Permission denied: Edit operation is blocked."));
|
||||
}
|
||||
}
|
||||
}
|
||||
Cmd::Bash { command, timeout } => {
|
||||
// Check permission with command context for pattern matching
|
||||
match perms.check(Tool::Bash, Some(&command)) {
|
||||
PermissionDecision::Allow => {
|
||||
// Check PreToolUse hook
|
||||
let event = HookEvent::PreToolUse {
|
||||
tool: "Bash".to_string(),
|
||||
args: serde_json::json!({"command": &command, "timeout": timeout}),
|
||||
};
|
||||
match hook_mgr.execute(&event, Some(5000)).await? {
|
||||
HookResult::Deny => {
|
||||
return Err(eyre!("Hook denied Bash operation"));
|
||||
}
|
||||
HookResult::Allow => {}
|
||||
}
|
||||
|
||||
let mut session = tools_bash::BashSession::new().await?;
|
||||
let output = session.execute(&command, timeout).await?;
|
||||
|
||||
// Print stdout
|
||||
if !output.stdout.is_empty() {
|
||||
print!("{}", output.stdout);
|
||||
}
|
||||
|
||||
// Print stderr to stderr
|
||||
if !output.stderr.is_empty() {
|
||||
eprint!("{}", output.stderr);
|
||||
}
|
||||
|
||||
session.close().await?;
|
||||
|
||||
// Exit with same code as command
|
||||
if !output.success {
|
||||
std::process::exit(output.exit_code);
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
PermissionDecision::Ask => {
|
||||
return Err(eyre!(
|
||||
"Permission denied: Bash operation requires approval. Use --mode code to allow."
|
||||
));
|
||||
}
|
||||
PermissionDecision::Deny => {
|
||||
return Err(eyre!("Permission denied: Bash operation is blocked."));
|
||||
}
|
||||
}
|
||||
}
|
||||
Cmd::Slash { command_name, args } => {
|
||||
// Check permission
|
||||
match perms.check(Tool::SlashCommand, None) {
|
||||
PermissionDecision::Allow => {
|
||||
// Check PreToolUse hook
|
||||
let event = HookEvent::PreToolUse {
|
||||
tool: "SlashCommand".to_string(),
|
||||
args: serde_json::json!({"command_name": &command_name, "args": &args}),
|
||||
};
|
||||
match hook_mgr.execute(&event, Some(5000)).await? {
|
||||
HookResult::Deny => {
|
||||
return Err(eyre!("Hook denied SlashCommand operation"));
|
||||
}
|
||||
HookResult::Allow => {}
|
||||
}
|
||||
|
||||
// Look for command file in .owlen/commands/ first
|
||||
let local_command_path = format!(".owlen/commands/{}.md", command_name);
|
||||
|
||||
// Try local commands first, then plugin commands
|
||||
let content = if let Ok(c) = tools_fs::read_file(&local_command_path) {
|
||||
c
|
||||
} else if let Some(plugin_path) = app_context.plugin_manager.all_commands().get(&command_name) {
|
||||
// Found in plugins
|
||||
tools_fs::read_file(&plugin_path.to_string_lossy())?
|
||||
} else {
|
||||
return Err(eyre!(
|
||||
"Slash command '{}' not found in .owlen/commands/ or plugins",
|
||||
command_name
|
||||
));
|
||||
};
|
||||
|
||||
// Parse with arguments
|
||||
let args_refs: Vec<&str> = args.iter().map(|s| s.as_str()).collect();
|
||||
let slash_cmd = tools_slash::parse_slash_command(&content, &args_refs)?;
|
||||
|
||||
// Resolve file references
|
||||
let resolved_body = slash_cmd.resolve_file_refs()?;
|
||||
|
||||
// Print the resolved command body
|
||||
println!("{}", resolved_body);
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
PermissionDecision::Ask => {
|
||||
return Err(eyre!(
|
||||
"Permission denied: Slash command requires approval. Use --mode code to allow."
|
||||
));
|
||||
}
|
||||
PermissionDecision::Deny => {
|
||||
return Err(eyre!("Permission denied: Slash command is blocked."));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let model = args.model.unwrap_or(settings.model.clone());
|
||||
let api_key = args.api_key.or(settings.api_key.clone());
|
||||
|
||||
// Use Ollama Cloud when model has "-cloud" suffix AND API key is set
|
||||
let use_cloud = model.ends_with("-cloud") && api_key.is_some();
|
||||
let client = if use_cloud {
|
||||
OllamaClient::with_cloud().with_api_key(api_key.unwrap())
|
||||
} else {
|
||||
let base_url = args.ollama_url.unwrap_or(settings.ollama_url.clone());
|
||||
let mut client = OllamaClient::new(base_url);
|
||||
if let Some(key) = api_key {
|
||||
client = client.with_api_key(key);
|
||||
}
|
||||
client
|
||||
};
|
||||
let opts = ChatOptions::new(model);
|
||||
|
||||
// Check if interactive mode (no prompt provided)
|
||||
if args.prompt.is_empty() {
|
||||
// Use TUI mode unless --no-tui flag is set or not a TTY
|
||||
if !args.no_tui && atty::is(atty::Stream::Stdout) {
|
||||
// Launch TUI
|
||||
// Note: For now, TUI doesn't use plugin manager directly
|
||||
// In the future, we'll integrate plugin commands into TUI
|
||||
return ui::run(client, opts, perms, settings).await;
|
||||
}
|
||||
|
||||
// Legacy text-based REPL
|
||||
println!("🤖 Owlen Interactive Mode");
|
||||
println!("Model: {}", opts.model);
|
||||
println!("Mode: {:?}", settings.mode);
|
||||
|
||||
// Show loaded plugins
|
||||
let plugins = app_context.plugin_manager.plugins();
|
||||
if !plugins.is_empty() {
|
||||
println!("Plugins: {} loaded", plugins.len());
|
||||
}
|
||||
|
||||
println!("Type your message or /help for commands. Press Ctrl+C to exit.\n");
|
||||
|
||||
use std::io::{stdin, BufRead};
|
||||
let stdin = stdin();
|
||||
let mut lines = stdin.lock().lines();
|
||||
let mut stats = agent_core::SessionStats::new();
|
||||
let mut history = agent_core::SessionHistory::new();
|
||||
let mut checkpoint_mgr = agent_core::CheckpointManager::new(
|
||||
std::path::PathBuf::from(".owlen/checkpoints")
|
||||
);
|
||||
|
||||
loop {
|
||||
print!("> ");
|
||||
std::io::stdout().flush().ok();
|
||||
|
||||
if let Some(Ok(line)) = lines.next() {
|
||||
let input = line.trim();
|
||||
if input.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Handle slash commands
|
||||
if input.starts_with('/') {
|
||||
match input {
|
||||
"/help" => {
|
||||
println!("\n📖 Available Commands:");
|
||||
println!(" /help - Show this help message");
|
||||
println!(" /status - Show session status");
|
||||
println!(" /permissions - Show permission settings");
|
||||
println!(" /cost - Show token usage and timing");
|
||||
println!(" /history - Show conversation history");
|
||||
println!(" /checkpoint - Save current session state");
|
||||
println!(" /checkpoints - List all saved checkpoints");
|
||||
println!(" /rewind <id> - Restore session from checkpoint");
|
||||
println!(" /clear - Clear conversation history");
|
||||
println!(" /plugins - Show loaded plugins and commands");
|
||||
println!(" /exit - Exit interactive mode");
|
||||
|
||||
// Show plugin commands if any are loaded
|
||||
let plugin_commands = app_context.plugin_manager.all_commands();
|
||||
if !plugin_commands.is_empty() {
|
||||
println!("\n📦 Plugin Commands:");
|
||||
for (name, _path) in &plugin_commands {
|
||||
println!(" /{}", name);
|
||||
}
|
||||
}
|
||||
}
|
||||
"/status" => {
|
||||
println!("\n📊 Session Status:");
|
||||
println!(" Model: {}", opts.model);
|
||||
println!(" Mode: {:?}", settings.mode);
|
||||
println!(" Messages: {}", stats.total_messages);
|
||||
println!(" Tools: {} calls", stats.total_tool_calls);
|
||||
let elapsed = stats.start_time.elapsed().unwrap_or_default();
|
||||
println!(" Uptime: {}", agent_core::SessionStats::format_duration(elapsed));
|
||||
}
|
||||
"/permissions" => {
|
||||
println!("\n🔒 Permission Settings:");
|
||||
println!(" Mode: {:?}", perms.mode());
|
||||
println!("\n Read-only tools: Read, Grep, Glob, NotebookRead");
|
||||
match perms.mode() {
|
||||
permissions::Mode::Plan => {
|
||||
println!(" ✅ Allowed (plan mode)");
|
||||
println!("\n Write tools: Write, Edit, NotebookEdit");
|
||||
println!(" ❓ Ask permission");
|
||||
println!("\n System tools: Bash");
|
||||
println!(" ❓ Ask permission");
|
||||
}
|
||||
permissions::Mode::AcceptEdits => {
|
||||
println!(" ✅ Allowed");
|
||||
println!("\n Write tools: Write, Edit, NotebookEdit");
|
||||
println!(" ✅ Allowed (acceptEdits mode)");
|
||||
println!("\n System tools: Bash");
|
||||
println!(" ❓ Ask permission");
|
||||
}
|
||||
permissions::Mode::Code => {
|
||||
println!(" ✅ Allowed");
|
||||
println!("\n Write tools: Write, Edit, NotebookEdit");
|
||||
println!(" ✅ Allowed (code mode)");
|
||||
println!("\n System tools: Bash");
|
||||
println!(" ✅ Allowed (code mode)");
|
||||
}
|
||||
}
|
||||
}
|
||||
"/cost" => {
|
||||
println!("\n💰 Token Usage & Timing:");
|
||||
println!(" Est. Tokens: ~{}", stats.estimated_tokens);
|
||||
println!(" Total Time: {}", agent_core::SessionStats::format_duration(stats.total_duration));
|
||||
if stats.total_messages > 0 {
|
||||
let avg_time = stats.total_duration / stats.total_messages as u32;
|
||||
println!(" Avg/Message: {}", agent_core::SessionStats::format_duration(avg_time));
|
||||
}
|
||||
println!("\n Note: Ollama is free - no cost incurred!");
|
||||
}
|
||||
"/history" => {
|
||||
println!("\n📜 Conversation History:");
|
||||
if history.user_prompts.is_empty() {
|
||||
println!(" (No messages yet)");
|
||||
} else {
|
||||
for (i, (user, assistant)) in history.user_prompts.iter()
|
||||
.zip(history.assistant_responses.iter()).enumerate() {
|
||||
println!("\n [{}] User: {}", i + 1, user);
|
||||
println!(" Assistant: {}...",
|
||||
assistant.chars().take(100).collect::<String>());
|
||||
}
|
||||
}
|
||||
if !history.tool_calls.is_empty() {
|
||||
println!("\n Tool Calls: {}", history.tool_calls.len());
|
||||
}
|
||||
}
|
||||
"/checkpoint" => {
|
||||
let checkpoint_id = format!("checkpoint-{}",
|
||||
SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs()
|
||||
);
|
||||
match checkpoint_mgr.save_checkpoint(
|
||||
checkpoint_id.clone(),
|
||||
stats.clone(),
|
||||
&history,
|
||||
) {
|
||||
Ok(checkpoint) => {
|
||||
println!("\n💾 Checkpoint saved: {}", checkpoint_id);
|
||||
if !checkpoint.file_diffs.is_empty() {
|
||||
println!(" Files tracked: {}", checkpoint.file_diffs.len());
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("\n❌ Failed to save checkpoint: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
"/checkpoints" => {
|
||||
match checkpoint_mgr.list_checkpoints() {
|
||||
Ok(checkpoints) => {
|
||||
if checkpoints.is_empty() {
|
||||
println!("\n📋 No checkpoints saved yet");
|
||||
} else {
|
||||
println!("\n📋 Saved Checkpoints:");
|
||||
for (i, cp_id) in checkpoints.iter().enumerate() {
|
||||
println!(" [{}] {}", i + 1, cp_id);
|
||||
}
|
||||
println!("\n Use /rewind <id> to restore");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("\n❌ Failed to list checkpoints: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
"/clear" => {
|
||||
history.clear();
|
||||
stats = agent_core::SessionStats::new();
|
||||
println!("\n🗑️ Session history cleared!");
|
||||
}
|
||||
"/plugins" => {
|
||||
let plugins = app_context.plugin_manager.plugins();
|
||||
if plugins.is_empty() {
|
||||
println!("\n📦 No plugins loaded");
|
||||
println!(" Place plugins in:");
|
||||
println!(" - ~/.config/owlen/plugins (user plugins)");
|
||||
println!(" - .owlen/plugins (project plugins)");
|
||||
} else {
|
||||
println!("\n📦 Loaded Plugins:");
|
||||
for plugin in plugins {
|
||||
println!("\n {} v{}", plugin.manifest.name, plugin.manifest.version);
|
||||
if let Some(desc) = &plugin.manifest.description {
|
||||
println!(" {}", desc);
|
||||
}
|
||||
if let Some(author) = &plugin.manifest.author {
|
||||
println!(" Author: {}", author);
|
||||
}
|
||||
|
||||
let commands = plugin.all_command_names();
|
||||
if !commands.is_empty() {
|
||||
println!(" Commands: {}", commands.join(", "));
|
||||
}
|
||||
|
||||
let agents = plugin.all_agent_names();
|
||||
if !agents.is_empty() {
|
||||
println!(" Agents: {}", agents.join(", "));
|
||||
}
|
||||
|
||||
let skills = plugin.all_skill_names();
|
||||
if !skills.is_empty() {
|
||||
println!(" Skills: {}", skills.join(", "));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"/exit" => {
|
||||
println!("\n👋 Goodbye!");
|
||||
break;
|
||||
}
|
||||
cmd if cmd.starts_with("/rewind ") => {
|
||||
let checkpoint_id = cmd.strip_prefix("/rewind ").unwrap().trim();
|
||||
match checkpoint_mgr.rewind_to(checkpoint_id) {
|
||||
Ok(restored_files) => {
|
||||
println!("\n⏪ Rewound to checkpoint: {}", checkpoint_id);
|
||||
if !restored_files.is_empty() {
|
||||
println!(" Restored files:");
|
||||
for file in restored_files {
|
||||
println!(" - {}", file.display());
|
||||
}
|
||||
}
|
||||
// Load the checkpoint to restore history and stats
|
||||
if let Ok(checkpoint) = checkpoint_mgr.load_checkpoint(checkpoint_id) {
|
||||
stats = checkpoint.stats;
|
||||
history.user_prompts = checkpoint.user_prompts;
|
||||
history.assistant_responses = checkpoint.assistant_responses;
|
||||
history.tool_calls = checkpoint.tool_calls;
|
||||
println!(" Session state restored");
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("\n❌ Failed to rewind: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
println!("\n❌ Unknown command: {}", input);
|
||||
println!(" Type /help for available commands");
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Regular message - run through agent loop
|
||||
history.add_user_message(input.to_string());
|
||||
let start = SystemTime::now();
|
||||
|
||||
let ctx = agent_core::ToolContext::new();
|
||||
match agent_core::run_agent_loop(&client, input, &opts, &perms, &ctx).await {
|
||||
Ok(response) => {
|
||||
println!("\n{}", response);
|
||||
history.add_assistant_message(response.clone());
|
||||
|
||||
// Update stats
|
||||
let duration = start.elapsed().unwrap_or_default();
|
||||
let tokens = (input.len() + response.len()) / 4; // Rough estimate
|
||||
stats.record_message(tokens, duration);
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("\n❌ Error: {}", e);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Non-interactive mode - process single prompt
|
||||
let prompt = args.prompt.join(" ");
|
||||
let start_time = SystemTime::now();
|
||||
|
||||
// Handle different output formats
|
||||
let ctx = agent_core::ToolContext::new();
|
||||
match output_format {
|
||||
OutputFormat::Text => {
|
||||
// Text format: Use agent orchestrator with tool calling
|
||||
let response = agent_core::run_agent_loop(&client, &prompt, &opts, &perms, &ctx).await?;
|
||||
println!("{}", response);
|
||||
}
|
||||
OutputFormat::Json => {
|
||||
// JSON format: Use agent loop and output as JSON
|
||||
let response = agent_core::run_agent_loop(&client, &prompt, &opts, &perms, &ctx).await?;
|
||||
|
||||
let duration_ms = start_time.elapsed().unwrap().as_millis() as u64;
|
||||
let estimated_tokens = ((prompt.len() + response.len()) / 4) as u64;
|
||||
|
||||
let output = SessionOutput {
|
||||
session_id,
|
||||
messages: vec![
|
||||
serde_json::json!({"role": "user", "content": prompt}),
|
||||
serde_json::json!({"role": "assistant", "content": response}),
|
||||
],
|
||||
stats: Stats {
|
||||
total_tokens: estimated_tokens,
|
||||
prompt_tokens: Some((prompt.len() / 4) as u64),
|
||||
completion_tokens: Some((response.len() / 4) as u64),
|
||||
duration_ms,
|
||||
},
|
||||
result: None,
|
||||
tool: None,
|
||||
};
|
||||
|
||||
println!("{}", serde_json::to_string(&output)?);
|
||||
}
|
||||
OutputFormat::StreamJson => {
|
||||
// Stream-JSON format: emit session_start, response, and session_end
|
||||
let session_start = StreamEvent {
|
||||
event_type: "session_start".to_string(),
|
||||
session_id: Some(session_id.clone()),
|
||||
content: None,
|
||||
stats: None,
|
||||
};
|
||||
println!("{}", serde_json::to_string(&session_start)?);
|
||||
|
||||
let response = agent_core::run_agent_loop(&client, &prompt, &opts, &perms, &ctx).await?;
|
||||
|
||||
let chunk_event = StreamEvent {
|
||||
event_type: "chunk".to_string(),
|
||||
session_id: None,
|
||||
content: Some(response.clone()),
|
||||
stats: None,
|
||||
};
|
||||
println!("{}", serde_json::to_string(&chunk_event)?);
|
||||
|
||||
let duration_ms = start_time.elapsed().unwrap().as_millis() as u64;
|
||||
let estimated_tokens = ((prompt.len() + response.len()) / 4) as u64;
|
||||
|
||||
let session_end = StreamEvent {
|
||||
event_type: "session_end".to_string(),
|
||||
session_id: None,
|
||||
content: None,
|
||||
stats: Some(Stats {
|
||||
total_tokens: estimated_tokens,
|
||||
prompt_tokens: Some((prompt.len() / 4) as u64),
|
||||
completion_tokens: Some((response.len() / 4) as u64),
|
||||
duration_ms,
|
||||
}),
|
||||
};
|
||||
println!("{}", serde_json::to_string(&session_end)?);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
34
crates/app/cli/tests/chat_stream.rs
Normal file
34
crates/app/cli/tests/chat_stream.rs
Normal file
@@ -0,0 +1,34 @@
use assert_cmd::Command;
use httpmock::prelude::*;
use predicates::prelude::PredicateBooleanExt;

#[tokio::test]
async fn headless_streams_ndjson() {
    let server = MockServer::start_async().await;

    let response = concat!(
        r#"{"message":{"role":"assistant","content":"Hel"}}"#, "\n",
        r#"{"message":{"role":"assistant","content":"lo"}}"#, "\n",
        r#"{"done":true}"#, "\n",
    );

    // The CLI includes tools in the request, so we need to match any request to /api/chat
    // instead of matching exact body (which includes tool definitions)
    let _m = server.mock(|when, then| {
        when.method(POST)
            .path("/api/chat");
        then.status(200)
            .header("content-type", "application/x-ndjson")
            .body(response);
    });

    let mut cmd = Command::new(assert_cmd::cargo::cargo_bin!("owlen"));
    cmd.arg("--ollama-url").arg(server.base_url())
        .arg("--model").arg("qwen2.5")
        .arg("--print")
        .arg("hello");

    cmd.assert()
        .success()
        .stdout(predicates::str::contains("Hello").count(1).or(predicates::str::contains("Hel").and(predicates::str::contains("lo"))));
}
145
crates/app/cli/tests/headless.rs
Normal file
145
crates/app/cli/tests/headless.rs
Normal file
@@ -0,0 +1,145 @@
use assert_cmd::Command;
use serde_json::Value;
use std::fs;
use tempfile::tempdir;

#[test]
fn print_json_has_session_id_and_stats() {
    let mut cmd = Command::cargo_bin("owlen").unwrap();
    cmd.arg("--output-format")
        .arg("json")
        .arg("Say hello");

    let output = cmd.assert().success();
    let stdout = String::from_utf8_lossy(&output.get_output().stdout);

    // Parse JSON output
    let json: Value = serde_json::from_str(&stdout).expect("Output should be valid JSON");

    // Verify session_id exists
    assert!(json.get("session_id").is_some(), "JSON output should have session_id");
    let session_id = json["session_id"].as_str().unwrap();
    assert!(!session_id.is_empty(), "session_id should not be empty");

    // Verify stats exist
    assert!(json.get("stats").is_some(), "JSON output should have stats");
    let stats = &json["stats"];

    // Check for token counts
    assert!(stats.get("total_tokens").is_some(), "stats should have total_tokens");

    // Check for messages
    assert!(json.get("messages").is_some(), "JSON output should have messages");
}

#[test]
fn stream_json_sequence_is_well_formed() {
    let mut cmd = Command::cargo_bin("owlen").unwrap();
    cmd.arg("--output-format")
        .arg("stream-json")
        .arg("Say hello");

    let output = cmd.assert().success();
    let stdout = String::from_utf8_lossy(&output.get_output().stdout);

    // Stream-JSON is NDJSON - each line should be valid JSON
    let lines: Vec<&str> = stdout.lines().filter(|l| !l.is_empty()).collect();

    assert!(!lines.is_empty(), "Stream-JSON should produce at least one event");

    // Each line should be valid JSON
    for (i, line) in lines.iter().enumerate() {
        let json: Value = serde_json::from_str(line)
            .expect(&format!("Line {} should be valid JSON: {}", i, line));

        // Each event should have a type
        assert!(json.get("type").is_some(), "Event should have a type field");
    }

    // First event should be session_start
    let first: Value = serde_json::from_str(lines[0]).unwrap();
    assert_eq!(first["type"].as_str().unwrap(), "session_start");
    assert!(first.get("session_id").is_some());

    // Last event should be session_end or complete
    let last: Value = serde_json::from_str(lines[lines.len() - 1]).unwrap();
    let last_type = last["type"].as_str().unwrap();
    assert!(
        last_type == "session_end" || last_type == "complete",
        "Last event should be session_end or complete, got: {}",
        last_type
    );
}

#[test]
fn text_format_is_default() {
    let mut cmd = Command::cargo_bin("owlen").unwrap();
    cmd.arg("Say hello");

    let output = cmd.assert().success();
    let stdout = String::from_utf8_lossy(&output.get_output().stdout);

    // Text format should not be JSON
    assert!(serde_json::from_str::<Value>(&stdout).is_err(),
        "Default output should be text, not JSON");
}

#[test]
fn json_format_with_tool_execution() {
    let dir = tempdir().unwrap();
    let file = dir.path().join("test.txt");
    fs::write(&file, "hello world").unwrap();

    let mut cmd = Command::cargo_bin("owlen").unwrap();
    cmd.arg("--mode")
        .arg("code")
        .arg("--output-format")
        .arg("json")
        .arg("read")
        .arg(file.to_str().unwrap());

    let output = cmd.assert().success();
    let stdout = String::from_utf8_lossy(&output.get_output().stdout);

    let json: Value = serde_json::from_str(&stdout).expect("Output should be valid JSON");

    // Should have result
    assert!(json.get("result").is_some());

    // Should have tool info
    assert!(json.get("tool").is_some());
    assert_eq!(json["tool"].as_str().unwrap(), "Read");
}

#[test]
fn stream_json_includes_chunk_events() {
    let mut cmd = Command::cargo_bin("owlen").unwrap();
    cmd.arg("--output-format")
        .arg("stream-json")
        .arg("Say hello");

    let output = cmd.assert().success();
    let stdout = String::from_utf8_lossy(&output.get_output().stdout);

    let lines: Vec<&str> = stdout.lines().filter(|l| !l.is_empty()).collect();

    // Should have chunk events between session_start and session_end
    let chunk_events: Vec<&str> = lines.iter()
        .filter(|line| {
            if let Ok(json) = serde_json::from_str::<Value>(line) {
                json["type"].as_str() == Some("chunk")
            } else {
                false
            }
        })
        .copied()
        .collect();

    assert!(!chunk_events.is_empty(), "Should have at least one chunk event");

    // Each chunk should have content
    for chunk_line in chunk_events {
        let chunk: Value = serde_json::from_str(chunk_line).unwrap();
        assert!(chunk.get("content").is_some(), "Chunk should have content");
    }
}
255
crates/app/cli/tests/permissions.rs
Normal file
255
crates/app/cli/tests/permissions.rs
Normal file
@@ -0,0 +1,255 @@
use assert_cmd::Command;
use std::fs;
use tempfile::tempdir;

#[test]
fn plan_mode_allows_read_operations() {
    // Create a temp file to read
    let dir = tempdir().unwrap();
    let file = dir.path().join("test.txt");
    fs::write(&file, "hello world").unwrap();

    // Read operation should work in plan mode (default)
    let mut cmd = Command::new(assert_cmd::cargo::cargo_bin!("owlen"));
    cmd.arg("read").arg(file.to_str().unwrap());
    cmd.assert().success().stdout("hello world\n");
}

#[test]
fn plan_mode_allows_glob_operations() {
    let dir = tempdir().unwrap();
    fs::write(dir.path().join("a.txt"), "test").unwrap();
    fs::write(dir.path().join("b.txt"), "test").unwrap();

    let pattern = format!("{}/*.txt", dir.path().display());

    // Glob operation should work in plan mode (default)
    let mut cmd = Command::new(assert_cmd::cargo::cargo_bin!("owlen"));
    cmd.arg("glob").arg(&pattern);
    cmd.assert().success();
}

#[test]
fn plan_mode_allows_grep_operations() {
    let dir = tempdir().unwrap();
    fs::write(dir.path().join("test.txt"), "hello world\nfoo bar").unwrap();

    // Grep operation should work in plan mode (default)
    let mut cmd = Command::new(assert_cmd::cargo::cargo_bin!("owlen"));
    cmd.arg("grep").arg(dir.path().to_str().unwrap()).arg("hello");
    cmd.assert().success();
}

#[test]
fn mode_override_via_cli_flag() {
    let dir = tempdir().unwrap();
    let file = dir.path().join("test.txt");
    fs::write(&file, "content").unwrap();

    // Test with --mode code (should also allow read)
    let mut cmd = Command::new(assert_cmd::cargo::cargo_bin!("owlen"));
    cmd.arg("--mode")
        .arg("code")
        .arg("read")
        .arg(file.to_str().unwrap());
    cmd.assert().success().stdout("content\n");
}

#[test]
fn plan_mode_blocks_write_operations() {
    let dir = tempdir().unwrap();
    let file = dir.path().join("new.txt");

    // Write operation should be blocked in plan mode (default)
    let mut cmd = Command::new(assert_cmd::cargo::cargo_bin!("owlen"));
    cmd.arg("write").arg(file.to_str().unwrap()).arg("content");
    cmd.assert().failure();
}

#[test]
fn plan_mode_blocks_edit_operations() {
    let dir = tempdir().unwrap();
    let file = dir.path().join("test.txt");
    fs::write(&file, "old content").unwrap();

    // Edit operation should be blocked in plan mode (default)
    let mut cmd = Command::new(assert_cmd::cargo::cargo_bin!("owlen"));
    cmd.arg("edit")
        .arg(file.to_str().unwrap())
        .arg("old")
        .arg("new");
    cmd.assert().failure();
}

#[test]
fn accept_edits_mode_allows_write() {
    let dir = tempdir().unwrap();
    let file = dir.path().join("new.txt");

    // Write operation should work in acceptEdits mode
    let mut cmd = Command::new(assert_cmd::cargo::cargo_bin!("owlen"));
    cmd.arg("--mode")
        .arg("acceptEdits")
        .arg("write")
        .arg(file.to_str().unwrap())
        .arg("new content");
    cmd.assert().success();

    // Verify file was written
    assert_eq!(fs::read_to_string(&file).unwrap(), "new content");
}

#[test]
fn accept_edits_mode_allows_edit() {
    let dir = tempdir().unwrap();
    let file = dir.path().join("test.txt");
    fs::write(&file, "line 1\nline 2\nline 3").unwrap();

    // Edit operation should work in acceptEdits mode
    let mut cmd = Command::new(assert_cmd::cargo::cargo_bin!("owlen"));
    cmd.arg("--mode")
        .arg("acceptEdits")
        .arg("edit")
        .arg(file.to_str().unwrap())
        .arg("line 2")
        .arg("modified line");
    cmd.assert().success();

    // Verify file was edited
    assert_eq!(
        fs::read_to_string(&file).unwrap(),
        "line 1\nmodified line\nline 3"
    );
}

#[test]
fn code_mode_allows_all_operations() {
    let dir = tempdir().unwrap();
    let file = dir.path().join("test.txt");

    // Write in code mode
    let mut cmd = Command::new(assert_cmd::cargo::cargo_bin!("owlen"));
    cmd.arg("--mode")
        .arg("code")
        .arg("write")
        .arg(file.to_str().unwrap())
        .arg("initial content");
    cmd.assert().success();

    // Edit in code mode
    let mut cmd = Command::new(assert_cmd::cargo::cargo_bin!("owlen"));
    cmd.arg("--mode")
        .arg("code")
        .arg("edit")
        .arg(file.to_str().unwrap())
        .arg("initial")
        .arg("modified");
    cmd.assert().success();

    assert_eq!(fs::read_to_string(&file).unwrap(), "modified content");
}

#[test]
fn plan_mode_blocks_bash_operations() {
    // Bash operation should be blocked in plan mode (default)
    let mut cmd = Command::new(assert_cmd::cargo::cargo_bin!("owlen"));
    cmd.arg("bash").arg("echo hello");
    cmd.assert().failure();
}

#[test]
fn code_mode_allows_bash() {
    // Bash operation should work in code mode
    let mut cmd = Command::new(assert_cmd::cargo::cargo_bin!("owlen"));
    cmd.arg("--mode").arg("code").arg("bash").arg("echo hello");
    cmd.assert().success().stdout("hello\n");
}

#[test]
fn bash_command_timeout_works() {
    // Test that timeout works
    let mut cmd = Command::new(assert_cmd::cargo::cargo_bin!("owlen"));
    cmd.arg("--mode")
        .arg("code")
        .arg("bash")
        .arg("sleep 10")
        .arg("--timeout")
        .arg("1000");
    cmd.assert().failure();
}

#[test]
fn slash_command_works() {
    // Create .owlen/commands directory in temp dir
    let dir = tempdir().unwrap();
    let commands_dir = dir.path().join(".owlen/commands");
    fs::create_dir_all(&commands_dir).unwrap();

    // Create a test slash command
    let command_content = r#"---
description: "Test command"
---
Hello from slash command!
Args: $ARGUMENTS
First: $1
"#;
    let command_file = commands_dir.join("test.md");
    fs::write(&command_file, command_content).unwrap();

    // Execute slash command with args from the temp directory
    let mut cmd = Command::new(assert_cmd::cargo::cargo_bin!("owlen"));
    cmd.current_dir(dir.path())
        .arg("--mode")
        .arg("code")
        .arg("slash")
        .arg("test")
        .arg("arg1");

    cmd.assert()
        .success()
        .stdout(predicates::str::contains("Hello from slash command!"))
        .stdout(predicates::str::contains("Args: arg1"))
        .stdout(predicates::str::contains("First: arg1"));
}

#[test]
fn slash_command_file_refs() {
    let dir = tempdir().unwrap();
    let commands_dir = dir.path().join(".owlen/commands");
    fs::create_dir_all(&commands_dir).unwrap();

    // Create a file to reference
    let data_file = dir.path().join("data.txt");
    fs::write(&data_file, "Referenced content").unwrap();

    // Create slash command with file reference
    let command_content = format!("File content: @{}", data_file.display());
    fs::write(commands_dir.join("reftest.md"), command_content).unwrap();

    // Execute slash command
    let mut cmd = Command::new(assert_cmd::cargo::cargo_bin!("owlen"));
    cmd.current_dir(dir.path())
        .arg("--mode")
        .arg("code")
        .arg("slash")
        .arg("reftest");

    cmd.assert()
        .success()
        .stdout(predicates::str::contains("Referenced content"));
}

#[test]
fn slash_command_not_found() {
    let dir = tempdir().unwrap();

    // Try to execute non-existent slash command
    let mut cmd = Command::new(assert_cmd::cargo::cargo_bin!("owlen"));
    cmd.current_dir(dir.path())
        .arg("--mode")
        .arg("code")
        .arg("slash")
        .arg("nonexistent");

    cmd.assert().failure();
}
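
Reviewer note: taken together, these tests pin down the permission matrix that the CLI is expected to enforce: plan allows only read, glob, and grep; acceptEdits additionally allows write and edit; code allows everything, including bash. A minimal gate consistent with the tests might look like the sketch below. It is illustrative only: the real check presumably lives in the platform/permissions crate (not shown in this hunk), and the acceptEdits/bash arm is an assumption, since no test above covers it.

// Sketch only: a mode/operation gate that would satisfy the tests above.
#[derive(Clone, Copy, PartialEq, Eq)]
enum PermissionMode { Plan, AcceptEdits, Code }

fn operation_allowed(mode: PermissionMode, op: &str) -> bool {
    match (mode, op) {
        // Read-only tools pass in every mode.
        (_, "read" | "glob" | "grep") => true,
        // Plan blocks all mutating operations and bash.
        (PermissionMode::Plan, _) => false,
        // acceptEdits allows file mutations; bash is assumed blocked (not covered by the tests).
        (PermissionMode::AcceptEdits, "write" | "edit") => true,
        (PermissionMode::AcceptEdits, _) => false,
        // code allows everything, including bash.
        (PermissionMode::Code, _) => true,
    }
}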
27  crates/app/ui/Cargo.toml  Normal file
@@ -0,0 +1,27 @@
[package]
name = "ui"
version = "0.1.0"
edition.workspace = true
license.workspace = true
rust-version.workspace = true

[dependencies]
color-eyre = "0.6"
crossterm = { version = "0.28", features = ["event-stream"] }
ratatui = "0.28"
tokio = { version = "1", features = ["full"] }
futures = "0.3"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
unicode-width = "0.2"
textwrap = "0.16"
syntect = { version = "5.0", default-features = false, features = ["default-syntaxes", "default-themes", "regex-onig"] }
pulldown-cmark = "0.11"

# Internal dependencies
agent-core = { path = "../../core/agent" }
permissions = { path = "../../platform/permissions" }
llm-core = { path = "../../llm/core" }
llm-ollama = { path = "../../llm/ollama" }
config-agent = { path = "../../platform/config" }
tools-todo = { path = "../../tools/todo" }
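
Reviewer note: the crossterm "event-stream" feature and the futures dependency go together: the feature enables crossterm's async EventStream, which the UI can poll as a stream inside the tokio runtime. A minimal sketch (illustrative only, not code from this change):

// Illustrative: reads one terminal event asynchronously via the event-stream feature.
use crossterm::event::{Event, EventStream};
use futures::StreamExt;

async fn next_event() -> Option<Event> {
    let mut events = EventStream::new();
    events.next().await.and_then(|res| res.ok())
}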
1101  crates/app/ui/src/app.rs  Normal file
File diff suppressed because it is too large
226  crates/app/ui/src/completions.rs  Normal file
@@ -0,0 +1,226 @@
//! Command completion engine for the TUI
//!
//! Provides Tab-completion for slash commands, file paths, and tool names.

use std::path::Path;

/// A single completion suggestion
#[derive(Debug, Clone)]
pub struct Completion {
    /// The text to insert
    pub text: String,
    /// Description of what this completion does
    pub description: String,
    /// Source of the completion (e.g., "builtin", "plugin:name")
    pub source: String,
}

/// Information about a command for completion purposes
#[derive(Debug, Clone)]
pub struct CommandInfo {
    /// Command name (without leading /)
    pub name: String,
    /// Command description
    pub description: String,
    /// Source of the command
    pub source: String,
}

impl CommandInfo {
    pub fn new(name: &str, description: &str, source: &str) -> Self {
        Self {
            name: name.to_string(),
            description: description.to_string(),
            source: source.to_string(),
        }
    }
}

/// Completion engine for the TUI
pub struct CompletionEngine {
    /// Available commands
    commands: Vec<CommandInfo>,
}

impl Default for CompletionEngine {
    fn default() -> Self {
        Self::new()
    }
}

impl CompletionEngine {
    pub fn new() -> Self {
        Self {
            commands: Self::builtin_commands(),
        }
    }

    /// Get built-in commands
    fn builtin_commands() -> Vec<CommandInfo> {
        vec![
            CommandInfo::new("help", "Show available commands and help", "builtin"),
            CommandInfo::new("clear", "Clear the screen", "builtin"),
            CommandInfo::new("mcp", "List MCP servers and their tools", "builtin"),
            CommandInfo::new("hooks", "Show loaded hooks", "builtin"),
            CommandInfo::new("compact", "Compact conversation context", "builtin"),
            CommandInfo::new("mode", "Switch permission mode (plan/edit/code)", "builtin"),
            CommandInfo::new("provider", "Switch LLM provider", "builtin"),
            CommandInfo::new("model", "Switch LLM model", "builtin"),
            CommandInfo::new("checkpoint", "Create a checkpoint", "builtin"),
            CommandInfo::new("rewind", "Rewind to a checkpoint", "builtin"),
        ]
    }

    /// Add commands from plugins
    pub fn add_plugin_commands(&mut self, plugin_name: &str, commands: Vec<CommandInfo>) {
        for mut cmd in commands {
            cmd.source = format!("plugin:{}", plugin_name);
            self.commands.push(cmd);
        }
    }

    /// Add a single command
    pub fn add_command(&mut self, command: CommandInfo) {
        self.commands.push(command);
    }

    /// Get completions for the given input
    pub fn complete(&self, input: &str) -> Vec<Completion> {
        if input.starts_with('/') {
            self.complete_command(&input[1..])
        } else if input.starts_with('@') {
            self.complete_file_path(&input[1..])
        } else {
            vec![]
        }
    }

    /// Complete a slash command
    fn complete_command(&self, partial: &str) -> Vec<Completion> {
        let partial_lower = partial.to_lowercase();

        self.commands
            .iter()
            .filter(|cmd| {
                // Match if name starts with partial, or contains partial (fuzzy)
                cmd.name.to_lowercase().starts_with(&partial_lower)
                    || (partial.len() >= 2 && cmd.name.to_lowercase().contains(&partial_lower))
            })
            .map(|cmd| Completion {
                text: format!("/{}", cmd.name),
                description: cmd.description.clone(),
                source: cmd.source.clone(),
            })
            .collect()
    }

    /// Complete a file path
    fn complete_file_path(&self, partial: &str) -> Vec<Completion> {
        let path = Path::new(partial);

        // Get the directory to search and the prefix to match
        let (dir, prefix) = if partial.ends_with('/') || partial.is_empty() {
            (partial, "")
        } else {
            let parent = path.parent().map(|p| p.to_str().unwrap_or("")).unwrap_or("");
            let file_name = path.file_name().and_then(|f| f.to_str()).unwrap_or("");
            (parent, file_name)
        };

        // Search directory
        let search_dir = if dir.is_empty() { "." } else { dir };

        match std::fs::read_dir(search_dir) {
            Ok(entries) => {
                entries
                    .filter_map(|entry| entry.ok())
                    .filter(|entry| {
                        let name = entry.file_name();
                        let name_str = name.to_string_lossy();
                        // Skip hidden files unless user started typing with .
                        if !prefix.starts_with('.') && name_str.starts_with('.') {
                            return false;
                        }
                        name_str.to_lowercase().starts_with(&prefix.to_lowercase())
                    })
                    .map(|entry| {
                        let name = entry.file_name();
                        let name_str = name.to_string_lossy();
                        let is_dir = entry.file_type().map(|t| t.is_dir()).unwrap_or(false);

                        let full_path = if dir.is_empty() {
                            name_str.to_string()
                        } else if dir.ends_with('/') {
                            format!("{}{}", dir, name_str)
                        } else {
                            format!("{}/{}", dir, name_str)
                        };

                        Completion {
                            text: format!("@{}{}", full_path, if is_dir { "/" } else { "" }),
                            description: if is_dir { "Directory".to_string() } else { "File".to_string() },
                            source: "filesystem".to_string(),
                        }
                    })
                    .collect()
            }
            Err(_) => vec![],
        }
    }

    /// Get all commands (for /help display)
    pub fn all_commands(&self) -> &[CommandInfo] {
        &self.commands
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_command_completion_exact() {
        let engine = CompletionEngine::new();
        let completions = engine.complete("/help");
        assert!(!completions.is_empty());
        assert!(completions.iter().any(|c| c.text == "/help"));
    }

    #[test]
    fn test_command_completion_partial() {
        let engine = CompletionEngine::new();
        let completions = engine.complete("/hel");
        assert!(!completions.is_empty());
        assert!(completions.iter().any(|c| c.text == "/help"));
    }

    #[test]
    fn test_command_completion_fuzzy() {
        let engine = CompletionEngine::new();
        // "cle" should match "clear"
        let completions = engine.complete("/cle");
        assert!(!completions.is_empty());
        assert!(completions.iter().any(|c| c.text == "/clear"));
    }

    #[test]
    fn test_command_info() {
        let info = CommandInfo::new("test", "A test command", "builtin");
        assert_eq!(info.name, "test");
        assert_eq!(info.description, "A test command");
        assert_eq!(info.source, "builtin");
    }

    #[test]
    fn test_add_plugin_commands() {
        let mut engine = CompletionEngine::new();
        let plugin_cmds = vec![
            CommandInfo::new("custom", "A custom command", ""),
        ];
        engine.add_plugin_commands("my-plugin", plugin_cmds);

        let completions = engine.complete("/custom");
        assert!(!completions.is_empty());
        assert!(completions.iter().any(|c| c.source == "plugin:my-plugin"));
    }
}
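
Reviewer note: a short usage sketch of the engine above, assuming the module is in scope. Prefix matches work from the first character; substring matches only kick in once two or more characters are typed. The "deploy" command is a hypothetical plugin command, not one defined in this diff.

// Usage sketch for CompletionEngine.
let mut engine = CompletionEngine::new();
engine.add_command(CommandInfo::new("deploy", "Deploy the project", "plugin:ci")); // hypothetical
for c in engine.complete("/de") {
    // "/de" matches "deploy" by prefix and e.g. "mode"/"model"/"provider" by substring;
    // "@src/" would instead list entries under src/.
    println!("{:12} {}  [{}]", c.text, c.description, c.source);
}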
377  crates/app/ui/src/components/autocomplete.rs  Normal file
@@ -0,0 +1,377 @@
|
||||
//! Command autocomplete dropdown component
|
||||
//!
|
||||
//! Displays inline autocomplete suggestions when user types `/`.
|
||||
//! Supports fuzzy filtering as user types.
|
||||
|
||||
use crate::theme::Theme;
|
||||
use crossterm::event::{KeyCode, KeyEvent};
|
||||
use ratatui::{
|
||||
layout::Rect,
|
||||
style::Style,
|
||||
text::{Line, Span},
|
||||
widgets::{Block, Borders, Clear, Paragraph},
|
||||
Frame,
|
||||
};
|
||||
|
||||
/// An autocomplete option
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct AutocompleteOption {
|
||||
/// The trigger text (command name without /)
|
||||
pub trigger: String,
|
||||
/// Display text (e.g., "/model [name]")
|
||||
pub display: String,
|
||||
/// Short description
|
||||
pub description: String,
|
||||
/// Has submenu/subcommands
|
||||
pub has_submenu: bool,
|
||||
}
|
||||
|
||||
impl AutocompleteOption {
|
||||
pub fn new(trigger: &str, description: &str) -> Self {
|
||||
Self {
|
||||
trigger: trigger.to_string(),
|
||||
display: format!("/{}", trigger),
|
||||
description: description.to_string(),
|
||||
has_submenu: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_args(trigger: &str, args: &str, description: &str) -> Self {
|
||||
Self {
|
||||
trigger: trigger.to_string(),
|
||||
display: format!("/{} {}", trigger, args),
|
||||
description: description.to_string(),
|
||||
has_submenu: false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_submenu(trigger: &str, description: &str) -> Self {
|
||||
Self {
|
||||
trigger: trigger.to_string(),
|
||||
display: format!("/{}", trigger),
|
||||
description: description.to_string(),
|
||||
has_submenu: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Default command options
|
||||
fn default_options() -> Vec<AutocompleteOption> {
|
||||
vec![
|
||||
AutocompleteOption::new("help", "Show help"),
|
||||
AutocompleteOption::new("status", "Session info"),
|
||||
AutocompleteOption::with_args("model", "[name]", "Switch model"),
|
||||
AutocompleteOption::with_args("provider", "[name]", "Switch provider"),
|
||||
AutocompleteOption::new("history", "View history"),
|
||||
AutocompleteOption::new("checkpoint", "Save state"),
|
||||
AutocompleteOption::new("checkpoints", "List checkpoints"),
|
||||
AutocompleteOption::with_args("rewind", "[id]", "Restore"),
|
||||
AutocompleteOption::new("cost", "Token usage"),
|
||||
AutocompleteOption::new("clear", "Clear chat"),
|
||||
AutocompleteOption::new("compact", "Compact context"),
|
||||
AutocompleteOption::new("permissions", "Permission mode"),
|
||||
AutocompleteOption::new("themes", "List themes"),
|
||||
AutocompleteOption::with_args("theme", "[name]", "Switch theme"),
|
||||
AutocompleteOption::new("exit", "Exit"),
|
||||
]
|
||||
}
|
||||
|
||||
/// Autocomplete dropdown component
|
||||
pub struct Autocomplete {
|
||||
options: Vec<AutocompleteOption>,
|
||||
filtered: Vec<usize>, // indices into options
|
||||
selected: usize,
|
||||
visible: bool,
|
||||
theme: Theme,
|
||||
}
|
||||
|
||||
impl Autocomplete {
|
||||
pub fn new(theme: Theme) -> Self {
|
||||
let options = default_options();
|
||||
let filtered: Vec<usize> = (0..options.len()).collect();
|
||||
|
||||
Self {
|
||||
options,
|
||||
filtered,
|
||||
selected: 0,
|
||||
visible: false,
|
||||
theme,
|
||||
}
|
||||
}
|
||||
|
||||
/// Show autocomplete and reset filter
|
||||
pub fn show(&mut self) {
|
||||
self.visible = true;
|
||||
self.filtered = (0..self.options.len()).collect();
|
||||
self.selected = 0;
|
||||
}
|
||||
|
||||
/// Hide autocomplete
|
||||
pub fn hide(&mut self) {
|
||||
self.visible = false;
|
||||
}
|
||||
|
||||
/// Check if visible
|
||||
pub fn is_visible(&self) -> bool {
|
||||
self.visible
|
||||
}
|
||||
|
||||
/// Update filter based on current input (text after /)
|
||||
pub fn update_filter(&mut self, query: &str) {
|
||||
if query.is_empty() {
|
||||
self.filtered = (0..self.options.len()).collect();
|
||||
} else {
|
||||
let query_lower = query.to_lowercase();
|
||||
self.filtered = self.options
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, opt)| {
|
||||
// Fuzzy match: check if query chars appear in order
|
||||
fuzzy_match(&opt.trigger.to_lowercase(), &query_lower)
|
||||
})
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
}
|
||||
|
||||
// Reset selection if it's out of bounds
|
||||
if self.selected >= self.filtered.len() {
|
||||
self.selected = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Select next option
|
||||
pub fn select_next(&mut self) {
|
||||
if !self.filtered.is_empty() {
|
||||
self.selected = (self.selected + 1) % self.filtered.len();
|
||||
}
|
||||
}
|
||||
|
||||
/// Select previous option
|
||||
pub fn select_prev(&mut self) {
|
||||
if !self.filtered.is_empty() {
|
||||
self.selected = if self.selected == 0 {
|
||||
self.filtered.len() - 1
|
||||
} else {
|
||||
self.selected - 1
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the currently selected option's trigger
|
||||
pub fn confirm(&self) -> Option<String> {
|
||||
if self.filtered.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let idx = self.filtered[self.selected];
|
||||
Some(format!("/{}", self.options[idx].trigger))
|
||||
}
|
||||
|
||||
/// Handle key input, returns Some(command) if confirmed
|
||||
///
|
||||
/// Key behavior:
|
||||
/// - Tab: Confirm selection and insert into input
|
||||
/// - Down/Up: Navigate options
|
||||
/// - Enter: Pass through to submit (NotHandled)
|
||||
/// - Esc: Cancel autocomplete
|
||||
pub fn handle_key(&mut self, key: KeyEvent) -> AutocompleteResult {
|
||||
if !self.visible {
|
||||
return AutocompleteResult::NotHandled;
|
||||
}
|
||||
|
||||
match key.code {
|
||||
KeyCode::Tab => {
|
||||
// Tab confirms and inserts the selected command
|
||||
if let Some(cmd) = self.confirm() {
|
||||
self.hide();
|
||||
AutocompleteResult::Confirmed(cmd)
|
||||
} else {
|
||||
AutocompleteResult::Handled
|
||||
}
|
||||
}
|
||||
KeyCode::Down => {
|
||||
self.select_next();
|
||||
AutocompleteResult::Handled
|
||||
}
|
||||
KeyCode::BackTab | KeyCode::Up => {
|
||||
self.select_prev();
|
||||
AutocompleteResult::Handled
|
||||
}
|
||||
KeyCode::Enter => {
|
||||
// Enter should submit the message, not confirm autocomplete
|
||||
// Hide autocomplete and let Enter pass through
|
||||
self.hide();
|
||||
AutocompleteResult::NotHandled
|
||||
}
|
||||
KeyCode::Esc => {
|
||||
self.hide();
|
||||
AutocompleteResult::Cancelled
|
||||
}
|
||||
_ => AutocompleteResult::NotHandled,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update theme
|
||||
pub fn set_theme(&mut self, theme: Theme) {
|
||||
self.theme = theme;
|
||||
}
|
||||
|
||||
/// Add custom options (from plugins)
|
||||
pub fn add_options(&mut self, options: Vec<AutocompleteOption>) {
|
||||
self.options.extend(options);
|
||||
// Re-filter with all options
|
||||
self.filtered = (0..self.options.len()).collect();
|
||||
}
|
||||
|
||||
/// Render the autocomplete dropdown above the input line
|
||||
pub fn render(&self, frame: &mut Frame, input_area: Rect) {
|
||||
if !self.visible || self.filtered.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
// Calculate dropdown dimensions
|
||||
let max_visible = 8.min(self.filtered.len());
|
||||
let width = 40.min(input_area.width.saturating_sub(4));
|
||||
let height = (max_visible + 2) as u16; // +2 for borders
|
||||
|
||||
// Position above input, left-aligned with some padding
|
||||
let x = input_area.x + 2;
|
||||
let y = input_area.y.saturating_sub(height);
|
||||
|
||||
let dropdown_area = Rect::new(x, y, width, height);
|
||||
|
||||
// Clear area behind dropdown
|
||||
frame.render_widget(Clear, dropdown_area);
|
||||
|
||||
// Build option lines
|
||||
let mut lines: Vec<Line> = Vec::new();
|
||||
|
||||
for (display_idx, &opt_idx) in self.filtered.iter().take(max_visible).enumerate() {
|
||||
let opt = &self.options[opt_idx];
|
||||
let is_selected = display_idx == self.selected;
|
||||
|
||||
let style = if is_selected {
|
||||
self.theme.selected
|
||||
} else {
|
||||
Style::default()
|
||||
};
|
||||
|
||||
let mut spans = vec![
|
||||
Span::styled(" ", style),
|
||||
Span::styled("/", if is_selected { style } else { self.theme.cmd_slash }),
|
||||
Span::styled(&opt.trigger, if is_selected { style } else { self.theme.cmd_name }),
|
||||
];
|
||||
|
||||
// Submenu indicator
|
||||
if opt.has_submenu {
|
||||
spans.push(Span::styled(" >", if is_selected { style } else { self.theme.cmd_desc }));
|
||||
}
|
||||
|
||||
// Pad to fixed width for consistent selection highlighting
|
||||
let current_len: usize = spans.iter().map(|s| s.content.len()).sum();
|
||||
let padding = (width as usize).saturating_sub(current_len + 1);
|
||||
spans.push(Span::styled(" ".repeat(padding), style));
|
||||
|
||||
lines.push(Line::from(spans));
|
||||
}
|
||||
|
||||
// Show overflow indicator if needed
|
||||
if self.filtered.len() > max_visible {
|
||||
lines.push(Line::from(Span::styled(
|
||||
format!(" ... +{} more", self.filtered.len() - max_visible),
|
||||
self.theme.cmd_desc,
|
||||
)));
|
||||
}
|
||||
|
||||
let block = Block::default()
|
||||
.borders(Borders::ALL)
|
||||
.border_style(Style::default().fg(self.theme.palette.border))
|
||||
.style(self.theme.overlay_bg);
|
||||
|
||||
let paragraph = Paragraph::new(lines).block(block);
|
||||
|
||||
frame.render_widget(paragraph, dropdown_area);
|
||||
}
|
||||
}
|
||||
|
||||
/// Result of handling autocomplete key
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum AutocompleteResult {
|
||||
/// Key was not handled by autocomplete
|
||||
NotHandled,
|
||||
/// Key was handled, no action needed
|
||||
Handled,
|
||||
/// User confirmed selection, returns command string
|
||||
Confirmed(String),
|
||||
/// User cancelled autocomplete
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
/// Simple fuzzy match: check if query chars appear in order in text
|
||||
fn fuzzy_match(text: &str, query: &str) -> bool {
|
||||
let mut text_chars = text.chars().peekable();
|
||||
|
||||
for query_char in query.chars() {
|
||||
loop {
|
||||
match text_chars.next() {
|
||||
Some(c) if c == query_char => break,
|
||||
Some(_) => continue,
|
||||
None => return false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_fuzzy_match() {
|
||||
assert!(fuzzy_match("help", "h"));
|
||||
assert!(fuzzy_match("help", "he"));
|
||||
assert!(fuzzy_match("help", "hel"));
|
||||
assert!(fuzzy_match("help", "help"));
|
||||
assert!(fuzzy_match("help", "hp")); // fuzzy: h...p
|
||||
assert!(!fuzzy_match("help", "x"));
|
||||
assert!(!fuzzy_match("help", "helping")); // query longer than text
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_autocomplete_filter() {
|
||||
let theme = Theme::default();
|
||||
let mut ac = Autocomplete::new(theme);
|
||||
|
||||
ac.update_filter("he");
|
||||
assert!(ac.filtered.len() < ac.options.len());
|
||||
|
||||
// Should match "help"
|
||||
assert!(ac.filtered.iter().any(|&i| ac.options[i].trigger == "help"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_autocomplete_navigation() {
|
||||
let theme = Theme::default();
|
||||
let mut ac = Autocomplete::new(theme);
|
||||
ac.show();
|
||||
|
||||
assert_eq!(ac.selected, 0);
|
||||
ac.select_next();
|
||||
assert_eq!(ac.selected, 1);
|
||||
ac.select_prev();
|
||||
assert_eq!(ac.selected, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_autocomplete_confirm() {
|
||||
let theme = Theme::default();
|
||||
let mut ac = Autocomplete::new(theme);
|
||||
ac.show();
|
||||
|
||||
let cmd = ac.confirm();
|
||||
assert!(cmd.is_some());
|
||||
assert!(cmd.unwrap().starts_with("/"));
|
||||
}
|
||||
}
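
Reviewer note: a usage sketch of the dropdown defined above, assuming the module and a Theme are in scope. Tab confirms the highlighted entry, while Enter deliberately falls through (NotHandled) so it still submits the message.

// Usage sketch for Autocomplete (not code from this diff).
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};

let mut ac = Autocomplete::new(Theme::default());
ac.show();
ac.update_filter("th"); // fuzzy filter: keeps options whose letters contain 't' then 'h' in order, e.g. /theme and /themes
match ac.handle_key(KeyEvent::new(KeyCode::Tab, KeyModifiers::NONE)) {
    AutocompleteResult::Confirmed(cmd) => { /* insert `cmd` into the input line */ }
    AutocompleteResult::Handled | AutocompleteResult::Cancelled | AutocompleteResult::NotHandled => {}
}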
468  crates/app/ui/src/components/chat_panel.rs  Normal file
@@ -0,0 +1,468 @@
|
||||
//! Borderless chat panel component
|
||||
//!
|
||||
//! Displays chat messages with proper indentation, timestamps,
|
||||
//! and streaming indicators. Uses whitespace instead of borders.
|
||||
|
||||
use crate::theme::Theme;
|
||||
use ratatui::{
|
||||
layout::Rect,
|
||||
style::{Modifier, Style},
|
||||
text::{Line, Span, Text},
|
||||
widgets::{Paragraph, Scrollbar, ScrollbarOrientation, ScrollbarState},
|
||||
Frame,
|
||||
};
|
||||
use std::time::SystemTime;
|
||||
|
||||
/// Chat message types
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ChatMessage {
|
||||
User(String),
|
||||
Assistant(String),
|
||||
ToolCall { name: String, args: String },
|
||||
ToolResult { success: bool, output: String },
|
||||
System(String),
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
/// Get a timestamp for when the message was created (for display)
|
||||
pub fn timestamp_display() -> String {
|
||||
let now = SystemTime::now();
|
||||
let secs = now
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0);
|
||||
let hours = (secs / 3600) % 24;
|
||||
let mins = (secs / 60) % 60;
|
||||
format!("{:02}:{:02}", hours, mins)
|
||||
}
|
||||
}
|
||||
|
||||
/// Message with metadata for display
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DisplayMessage {
|
||||
pub message: ChatMessage,
|
||||
pub timestamp: String,
|
||||
pub focused: bool,
|
||||
}
|
||||
|
||||
impl DisplayMessage {
|
||||
pub fn new(message: ChatMessage) -> Self {
|
||||
Self {
|
||||
message,
|
||||
timestamp: ChatMessage::timestamp_display(),
|
||||
focused: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Borderless chat panel
|
||||
pub struct ChatPanel {
|
||||
messages: Vec<DisplayMessage>,
|
||||
scroll_offset: usize,
|
||||
auto_scroll: bool,
|
||||
total_lines: usize,
|
||||
focused_index: Option<usize>,
|
||||
is_streaming: bool,
|
||||
theme: Theme,
|
||||
}
|
||||
|
||||
impl ChatPanel {
|
||||
/// Create new borderless chat panel
|
||||
pub fn new(theme: Theme) -> Self {
|
||||
Self {
|
||||
messages: Vec::new(),
|
||||
scroll_offset: 0,
|
||||
auto_scroll: true,
|
||||
total_lines: 0,
|
||||
focused_index: None,
|
||||
is_streaming: false,
|
||||
theme,
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a new message
|
||||
pub fn add_message(&mut self, message: ChatMessage) {
|
||||
self.messages.push(DisplayMessage::new(message));
|
||||
self.auto_scroll = true;
|
||||
self.is_streaming = false;
|
||||
}
|
||||
|
||||
/// Append content to the last assistant message, or create a new one
|
||||
pub fn append_to_assistant(&mut self, content: &str) {
|
||||
if let Some(DisplayMessage {
|
||||
message: ChatMessage::Assistant(last_content),
|
||||
..
|
||||
}) = self.messages.last_mut()
|
||||
{
|
||||
last_content.push_str(content);
|
||||
} else {
|
||||
self.messages.push(DisplayMessage::new(ChatMessage::Assistant(
|
||||
content.to_string(),
|
||||
)));
|
||||
}
|
||||
self.auto_scroll = true;
|
||||
self.is_streaming = true;
|
||||
}
|
||||
|
||||
/// Set streaming state
|
||||
pub fn set_streaming(&mut self, streaming: bool) {
|
||||
self.is_streaming = streaming;
|
||||
}
|
||||
|
||||
/// Scroll up
|
||||
pub fn scroll_up(&mut self, amount: usize) {
|
||||
self.scroll_offset = self.scroll_offset.saturating_sub(amount);
|
||||
self.auto_scroll = false;
|
||||
}
|
||||
|
||||
/// Scroll down
|
||||
pub fn scroll_down(&mut self, amount: usize) {
|
||||
self.scroll_offset = self.scroll_offset.saturating_add(amount);
|
||||
let near_bottom_threshold = 5;
|
||||
if self.total_lines > 0 {
|
||||
let max_scroll = self.total_lines.saturating_sub(1);
|
||||
if self.scroll_offset.saturating_add(near_bottom_threshold) >= max_scroll {
|
||||
self.auto_scroll = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Scroll to bottom
|
||||
pub fn scroll_to_bottom(&mut self) {
|
||||
self.scroll_offset = self.total_lines.saturating_sub(1);
|
||||
self.auto_scroll = true;
|
||||
}
|
||||
|
||||
/// Page up
|
||||
pub fn page_up(&mut self, page_size: usize) {
|
||||
self.scroll_up(page_size.saturating_sub(2));
|
||||
}
|
||||
|
||||
/// Page down
|
||||
pub fn page_down(&mut self, page_size: usize) {
|
||||
self.scroll_down(page_size.saturating_sub(2));
|
||||
}
|
||||
|
||||
/// Focus next message
|
||||
pub fn focus_next(&mut self) {
|
||||
if self.messages.is_empty() {
|
||||
return;
|
||||
}
|
||||
self.focused_index = Some(match self.focused_index {
|
||||
Some(i) if i + 1 < self.messages.len() => i + 1,
|
||||
Some(_) => 0,
|
||||
None => 0,
|
||||
});
|
||||
}
|
||||
|
||||
/// Focus previous message
|
||||
pub fn focus_previous(&mut self) {
|
||||
if self.messages.is_empty() {
|
||||
return;
|
||||
}
|
||||
self.focused_index = Some(match self.focused_index {
|
||||
Some(0) => self.messages.len() - 1,
|
||||
Some(i) => i - 1,
|
||||
None => self.messages.len() - 1,
|
||||
});
|
||||
}
|
||||
|
||||
/// Clear focus
|
||||
pub fn clear_focus(&mut self) {
|
||||
self.focused_index = None;
|
||||
}
|
||||
|
||||
/// Get focused message index
|
||||
pub fn focused_index(&self) -> Option<usize> {
|
||||
self.focused_index
|
||||
}
|
||||
|
||||
/// Get focused message
|
||||
pub fn focused_message(&self) -> Option<&ChatMessage> {
|
||||
self.focused_index
|
||||
.and_then(|i| self.messages.get(i))
|
||||
.map(|m| &m.message)
|
||||
}
|
||||
|
||||
/// Update scroll position before rendering
|
||||
pub fn update_scroll(&mut self, area: Rect) {
|
||||
self.total_lines = self.count_total_lines(area);
|
||||
|
||||
if self.auto_scroll {
|
||||
let visible_height = area.height as usize;
|
||||
let max_scroll = self.total_lines.saturating_sub(visible_height);
|
||||
self.scroll_offset = max_scroll;
|
||||
} else {
|
||||
let visible_height = area.height as usize;
|
||||
let max_scroll = self.total_lines.saturating_sub(visible_height);
|
||||
self.scroll_offset = self.scroll_offset.min(max_scroll);
|
||||
}
|
||||
}
|
||||
|
||||
/// Count total lines for scroll calculation
|
||||
fn count_total_lines(&self, area: Rect) -> usize {
|
||||
let mut line_count = 0;
|
||||
let wrap_width = area.width.saturating_sub(4) as usize;
|
||||
|
||||
for msg in &self.messages {
|
||||
line_count += match &msg.message {
|
||||
ChatMessage::User(content) => {
|
||||
let wrapped = textwrap::wrap(content, wrap_width);
|
||||
wrapped.len() + 1 // +1 for spacing
|
||||
}
|
||||
ChatMessage::Assistant(content) => {
|
||||
let wrapped = textwrap::wrap(content, wrap_width);
|
||||
wrapped.len() + 1
|
||||
}
|
||||
ChatMessage::ToolCall { .. } => 2,
|
||||
ChatMessage::ToolResult { .. } => 2,
|
||||
ChatMessage::System(_) => 1,
|
||||
};
|
||||
}
|
||||
|
||||
line_count
|
||||
}
|
||||
|
||||
/// Render the borderless chat panel
|
||||
///
|
||||
/// Message display format (no symbols, clean typography):
|
||||
/// - Role: bold, appropriate color
|
||||
/// - Timestamp: dim, same line as role
|
||||
/// - Content: 2-space indent, normal weight
|
||||
/// - Blank line between messages
|
||||
pub fn render(&self, frame: &mut Frame, area: Rect) {
|
||||
let mut text_lines = Vec::new();
|
||||
let wrap_width = area.width.saturating_sub(4) as usize;
|
||||
|
||||
for (idx, display_msg) in self.messages.iter().enumerate() {
|
||||
let is_focused = self.focused_index == Some(idx);
|
||||
let is_last = idx == self.messages.len() - 1;
|
||||
|
||||
match &display_msg.message {
|
||||
ChatMessage::User(content) => {
|
||||
// Role line: "You" bold + timestamp dim
|
||||
text_lines.push(Line::from(vec![
|
||||
Span::styled(" ", Style::default()),
|
||||
Span::styled("You", self.theme.user_message),
|
||||
Span::styled(
|
||||
format!(" {}", display_msg.timestamp),
|
||||
self.theme.timestamp,
|
||||
),
|
||||
]));
|
||||
|
||||
// Message content with 2-space indent
|
||||
let wrapped = textwrap::wrap(content, wrap_width);
|
||||
for line in wrapped {
|
||||
let style = if is_focused {
|
||||
self.theme.user_message.add_modifier(Modifier::REVERSED)
|
||||
} else {
|
||||
self.theme.user_message.remove_modifier(Modifier::BOLD)
|
||||
};
|
||||
text_lines.push(Line::from(Span::styled(
|
||||
format!(" {}", line),
|
||||
style,
|
||||
)));
|
||||
}
|
||||
|
||||
// Focus hints
|
||||
if is_focused {
|
||||
text_lines.push(Line::from(Span::styled(
|
||||
" [y]copy [e]edit [r]retry",
|
||||
self.theme.status_dim,
|
||||
)));
|
||||
}
|
||||
|
||||
text_lines.push(Line::from(""));
|
||||
}
|
||||
|
||||
ChatMessage::Assistant(content) => {
|
||||
// Role line: streaming indicator (if active) + "Assistant" bold + timestamp
|
||||
let mut role_spans = vec![Span::styled(" ", Style::default())];
|
||||
|
||||
// Streaming indicator (subtle, no symbol)
|
||||
if is_last && self.is_streaming {
|
||||
role_spans.push(Span::styled(
|
||||
"... ",
|
||||
Style::default().fg(self.theme.palette.success),
|
||||
));
|
||||
}
|
||||
|
||||
role_spans.push(Span::styled(
|
||||
"Assistant",
|
||||
self.theme.assistant_message.add_modifier(Modifier::BOLD),
|
||||
));
|
||||
|
||||
role_spans.push(Span::styled(
|
||||
format!(" {}", display_msg.timestamp),
|
||||
self.theme.timestamp,
|
||||
));
|
||||
|
||||
text_lines.push(Line::from(role_spans));
|
||||
|
||||
// Content
|
||||
let wrapped = textwrap::wrap(content, wrap_width);
|
||||
for line in wrapped {
|
||||
let style = if is_focused {
|
||||
self.theme.assistant_message.add_modifier(Modifier::REVERSED)
|
||||
} else {
|
||||
self.theme.assistant_message
|
||||
};
|
||||
text_lines.push(Line::from(Span::styled(
|
||||
format!(" {}", line),
|
||||
style,
|
||||
)));
|
||||
}
|
||||
|
||||
// Focus hints
|
||||
if is_focused {
|
||||
text_lines.push(Line::from(Span::styled(
|
||||
" [y]copy [r]retry",
|
||||
self.theme.status_dim,
|
||||
)));
|
||||
}
|
||||
|
||||
text_lines.push(Line::from(""));
|
||||
}
|
||||
|
||||
ChatMessage::ToolCall { name, args } => {
|
||||
// Tool calls: name in tool color, args dimmed
|
||||
text_lines.push(Line::from(vec![
|
||||
Span::styled(" ", Style::default()),
|
||||
Span::styled(format!("{} ", name), self.theme.tool_call),
|
||||
Span::styled(
|
||||
truncate_str(args, 60),
|
||||
self.theme.tool_call.add_modifier(Modifier::DIM),
|
||||
),
|
||||
]));
|
||||
text_lines.push(Line::from(""));
|
||||
}
|
||||
|
||||
ChatMessage::ToolResult { success, output } => {
|
||||
// Tool results: status prefix + output
|
||||
let (prefix, style) = if *success {
|
||||
("ok ", self.theme.tool_result_success)
|
||||
} else {
|
||||
("err ", self.theme.tool_result_error)
|
||||
};
|
||||
|
||||
text_lines.push(Line::from(vec![
|
||||
Span::styled(" ", Style::default()),
|
||||
Span::styled(prefix, style),
|
||||
Span::styled(
|
||||
truncate_str(output, 100),
|
||||
style.remove_modifier(Modifier::BOLD),
|
||||
),
|
||||
]));
|
||||
text_lines.push(Line::from(""));
|
||||
}
|
||||
|
||||
ChatMessage::System(content) => {
|
||||
// System messages: just dim text, no prefix
|
||||
text_lines.push(Line::from(vec![
|
||||
Span::styled(" ", Style::default()),
|
||||
Span::styled(content.to_string(), self.theme.system_message),
|
||||
]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let text = Text::from(text_lines);
|
||||
let paragraph = Paragraph::new(text).scroll((self.scroll_offset as u16, 0));
|
||||
|
||||
frame.render_widget(paragraph, area);
|
||||
|
||||
// Render scrollbar if needed
|
||||
if self.total_lines > area.height as usize {
|
||||
let scrollbar = Scrollbar::default()
|
||||
.orientation(ScrollbarOrientation::VerticalRight)
|
||||
.begin_symbol(None)
|
||||
.end_symbol(None)
|
||||
.track_symbol(Some(" "))
|
||||
.thumb_symbol("│")
|
||||
.style(self.theme.status_dim);
|
||||
|
||||
let mut scrollbar_state = ScrollbarState::default()
|
||||
.content_length(self.total_lines)
|
||||
.position(self.scroll_offset);
|
||||
|
||||
frame.render_stateful_widget(scrollbar, area, &mut scrollbar_state);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get messages
|
||||
pub fn messages(&self) -> &[DisplayMessage] {
|
||||
&self.messages
|
||||
}
|
||||
|
||||
/// Clear all messages
|
||||
pub fn clear(&mut self) {
|
||||
self.messages.clear();
|
||||
self.scroll_offset = 0;
|
||||
self.focused_index = None;
|
||||
}
|
||||
|
||||
/// Update theme
|
||||
pub fn set_theme(&mut self, theme: Theme) {
|
||||
self.theme = theme;
|
||||
}
|
||||
}
|
||||
|
||||
/// Truncate a string to max length with ellipsis
|
||||
fn truncate_str(s: &str, max_len: usize) -> String {
|
||||
if s.len() <= max_len {
|
||||
s.to_string()
|
||||
} else {
|
||||
format!("{}...", &s[..max_len.saturating_sub(3)])
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_chat_panel_add_message() {
|
||||
let theme = Theme::default();
|
||||
let mut panel = ChatPanel::new(theme);
|
||||
|
||||
panel.add_message(ChatMessage::User("Hello".to_string()));
|
||||
panel.add_message(ChatMessage::Assistant("Hi there!".to_string()));
|
||||
|
||||
assert_eq!(panel.messages().len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_append_to_assistant() {
|
||||
let theme = Theme::default();
|
||||
let mut panel = ChatPanel::new(theme);
|
||||
|
||||
panel.append_to_assistant("Hello");
|
||||
panel.append_to_assistant(" world");
|
||||
|
||||
assert_eq!(panel.messages().len(), 1);
|
||||
if let ChatMessage::Assistant(content) = &panel.messages()[0].message {
|
||||
assert_eq!(content, "Hello world");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_focus_navigation() {
|
||||
let theme = Theme::default();
|
||||
let mut panel = ChatPanel::new(theme);
|
||||
|
||||
panel.add_message(ChatMessage::User("1".to_string()));
|
||||
panel.add_message(ChatMessage::User("2".to_string()));
|
||||
panel.add_message(ChatMessage::User("3".to_string()));
|
||||
|
||||
assert_eq!(panel.focused_index(), None);
|
||||
|
||||
panel.focus_next();
|
||||
assert_eq!(panel.focused_index(), Some(0));
|
||||
|
||||
panel.focus_next();
|
||||
assert_eq!(panel.focused_index(), Some(1));
|
||||
|
||||
panel.focus_previous();
|
||||
assert_eq!(panel.focused_index(), Some(0));
|
||||
}
|
||||
}
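
Reviewer note: the panel is driven in three steps: push complete messages, stream deltas into the last assistant message, then recompute scroll and draw each frame. A usage sketch, assuming the module above is in scope:

// Usage sketch for ChatPanel.
let mut panel = ChatPanel::new(Theme::default());
panel.add_message(ChatMessage::User("Summarise this repository".to_string()));
panel.append_to_assistant("Looking at the workspace ");
panel.append_to_assistant("layout..."); // grows the same Assistant message
panel.set_streaming(false);             // clears the "..." indicator on the role line
// Inside the draw loop:
//     panel.update_scroll(area);
//     panel.render(frame, area);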
322  crates/app/ui/src/components/command_help.rs  Normal file
@@ -0,0 +1,322 @@
|
||||
//! Command help overlay component
|
||||
//!
|
||||
//! Modal overlay that displays available commands in a structured format.
|
||||
//! Shown when user types `/help` or `?`. Supports scrolling with j/k or arrows.
|
||||
|
||||
use crate::theme::Theme;
|
||||
use crossterm::event::{KeyCode, KeyEvent};
|
||||
use ratatui::{
|
||||
layout::Rect,
|
||||
style::Style,
|
||||
text::{Line, Span},
|
||||
widgets::{Block, Borders, Clear, Paragraph, Scrollbar, ScrollbarOrientation, ScrollbarState},
|
||||
Frame,
|
||||
};
|
||||
|
||||
/// A single command definition
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Command {
|
||||
pub name: &'static str,
|
||||
pub args: Option<&'static str>,
|
||||
pub description: &'static str,
|
||||
}
|
||||
|
||||
impl Command {
|
||||
pub const fn new(name: &'static str, description: &'static str) -> Self {
|
||||
Self {
|
||||
name,
|
||||
args: None,
|
||||
description,
|
||||
}
|
||||
}
|
||||
|
||||
pub const fn with_args(name: &'static str, args: &'static str, description: &'static str) -> Self {
|
||||
Self {
|
||||
name,
|
||||
args: Some(args),
|
||||
description,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Built-in commands
|
||||
pub fn builtin_commands() -> Vec<Command> {
|
||||
vec![
|
||||
Command::new("help", "Show this help"),
|
||||
Command::new("status", "Current session info"),
|
||||
Command::with_args("model", "[name]", "Switch model"),
|
||||
Command::with_args("provider", "[name]", "Switch provider (ollama, anthropic, openai)"),
|
||||
Command::new("history", "Browse conversation history"),
|
||||
Command::new("checkpoint", "Save conversation state"),
|
||||
Command::new("checkpoints", "List saved checkpoints"),
|
||||
Command::with_args("rewind", "[id]", "Restore checkpoint"),
|
||||
Command::new("cost", "Show token usage"),
|
||||
Command::new("clear", "Clear conversation"),
|
||||
Command::new("compact", "Compact conversation context"),
|
||||
Command::new("permissions", "Show permission mode"),
|
||||
Command::new("themes", "List available themes"),
|
||||
Command::with_args("theme", "[name]", "Switch theme"),
|
||||
Command::new("exit", "Exit OWLEN"),
|
||||
]
|
||||
}
|
||||
|
||||
/// Command help overlay
|
||||
pub struct CommandHelp {
|
||||
commands: Vec<Command>,
|
||||
visible: bool,
|
||||
scroll_offset: usize,
|
||||
theme: Theme,
|
||||
}
|
||||
|
||||
impl CommandHelp {
|
||||
pub fn new(theme: Theme) -> Self {
|
||||
Self {
|
||||
commands: builtin_commands(),
|
||||
visible: false,
|
||||
scroll_offset: 0,
|
||||
theme,
|
||||
}
|
||||
}
|
||||
|
||||
/// Show the help overlay
|
||||
pub fn show(&mut self) {
|
||||
self.visible = true;
|
||||
self.scroll_offset = 0; // Reset scroll when showing
|
||||
}
|
||||
|
||||
/// Hide the help overlay
|
||||
pub fn hide(&mut self) {
|
||||
self.visible = false;
|
||||
}
|
||||
|
||||
/// Check if visible
|
||||
pub fn is_visible(&self) -> bool {
|
||||
self.visible
|
||||
}
|
||||
|
||||
/// Toggle visibility
|
||||
pub fn toggle(&mut self) {
|
||||
self.visible = !self.visible;
|
||||
if self.visible {
|
||||
self.scroll_offset = 0;
|
||||
}
|
||||
}
|
||||
|
||||
/// Scroll up by amount
|
||||
fn scroll_up(&mut self, amount: usize) {
|
||||
self.scroll_offset = self.scroll_offset.saturating_sub(amount);
|
||||
}
|
||||
|
||||
/// Scroll down by amount, respecting max
|
||||
fn scroll_down(&mut self, amount: usize, max_scroll: usize) {
|
||||
self.scroll_offset = (self.scroll_offset + amount).min(max_scroll);
|
||||
}
|
||||
|
||||
/// Handle key input, returns true if overlay handled the key
|
||||
pub fn handle_key(&mut self, key: KeyEvent) -> bool {
|
||||
if !self.visible {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Calculate max scroll (commands + padding lines - visible area)
|
||||
let total_lines = self.commands.len() + 3; // +3 for padding and footer
|
||||
let max_scroll = total_lines.saturating_sub(10); // Assume ~10 visible lines
|
||||
|
||||
match key.code {
|
||||
KeyCode::Esc | KeyCode::Char('q') | KeyCode::Char('?') => {
|
||||
self.hide();
|
||||
true
|
||||
}
|
||||
// Scroll navigation
|
||||
KeyCode::Up | KeyCode::Char('k') => {
|
||||
self.scroll_up(1);
|
||||
true
|
||||
}
|
||||
KeyCode::Down | KeyCode::Char('j') => {
|
||||
self.scroll_down(1, max_scroll);
|
||||
true
|
||||
}
|
||||
KeyCode::PageUp | KeyCode::Char('u') => {
|
||||
self.scroll_up(5);
|
||||
true
|
||||
}
|
||||
KeyCode::PageDown | KeyCode::Char('d') => {
|
||||
self.scroll_down(5, max_scroll);
|
||||
true
|
||||
}
|
||||
KeyCode::Home | KeyCode::Char('g') => {
|
||||
self.scroll_offset = 0;
|
||||
true
|
||||
}
|
||||
KeyCode::End | KeyCode::Char('G') => {
|
||||
self.scroll_offset = max_scroll;
|
||||
true
|
||||
}
|
||||
_ => true, // Consume all other keys while visible
|
||||
}
|
||||
}
|
||||
|
||||
/// Update theme
|
||||
pub fn set_theme(&mut self, theme: Theme) {
|
||||
self.theme = theme;
|
||||
}
|
||||
|
||||
/// Add plugin commands
|
||||
pub fn add_commands(&mut self, commands: Vec<Command>) {
|
||||
self.commands.extend(commands);
|
||||
}
|
||||
|
||||
/// Render the help overlay
|
||||
pub fn render(&self, frame: &mut Frame, area: Rect) {
|
||||
if !self.visible {
|
||||
return;
|
||||
}
|
||||
|
||||
// Calculate overlay dimensions
|
||||
let width = (area.width as f32 * 0.7).min(65.0) as u16;
|
||||
let max_height = area.height.saturating_sub(4);
|
||||
let content_height = self.commands.len() as u16 + 4; // +4 for padding and footer
|
||||
let height = content_height.min(max_height).max(8);
|
||||
|
||||
// Center the overlay
|
||||
let x = (area.width.saturating_sub(width)) / 2;
|
||||
let y = (area.height.saturating_sub(height)) / 2;
|
||||
|
||||
let overlay_area = Rect::new(x, y, width, height);
|
||||
|
||||
// Clear the area behind the overlay
|
||||
frame.render_widget(Clear, overlay_area);
|
||||
|
||||
// Build content lines
|
||||
let mut lines: Vec<Line> = Vec::new();
|
||||
|
||||
// Empty line for padding
|
||||
lines.push(Line::from(""));
|
||||
|
||||
// Command list
|
||||
for cmd in &self.commands {
|
||||
let name_with_args = if let Some(args) = cmd.args {
|
||||
format!("/{} {}", cmd.name, args)
|
||||
} else {
|
||||
format!("/{}", cmd.name)
|
||||
};
|
||||
|
||||
// Calculate padding for alignment
|
||||
let name_width: usize = 22;
|
||||
let padding = name_width.saturating_sub(name_with_args.len());
|
||||
|
||||
lines.push(Line::from(vec![
|
||||
Span::styled(" ", Style::default()),
|
||||
Span::styled("/", self.theme.cmd_slash),
|
||||
Span::styled(
|
||||
if let Some(args) = cmd.args {
|
||||
format!("{} {}", cmd.name, args)
|
||||
} else {
|
||||
cmd.name.to_string()
|
||||
},
|
||||
self.theme.cmd_name,
|
||||
),
|
||||
Span::raw(" ".repeat(padding)),
|
||||
Span::styled(cmd.description, self.theme.cmd_desc),
|
||||
]));
|
||||
}
|
||||
|
||||
// Empty line for padding
|
||||
lines.push(Line::from(""));
|
||||
|
||||
// Footer hint with scroll info
|
||||
let scroll_hint = if self.commands.len() > (height as usize - 4) {
|
||||
format!(" (scroll: j/k or ↑/↓)")
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
lines.push(Line::from(vec![
|
||||
Span::styled(" Press ", self.theme.cmd_desc),
|
||||
Span::styled("Esc", self.theme.cmd_name),
|
||||
Span::styled(" to close", self.theme.cmd_desc),
|
||||
Span::styled(scroll_hint, self.theme.cmd_desc),
|
||||
]));
|
||||
|
||||
// Create the block with border
|
||||
let block = Block::default()
|
||||
.title(" Commands ")
|
||||
.title_style(self.theme.popup_title)
|
||||
.borders(Borders::ALL)
|
||||
.border_style(self.theme.popup_border)
|
||||
.style(self.theme.overlay_bg);
|
||||
|
||||
let paragraph = Paragraph::new(lines)
|
||||
.block(block)
|
||||
.scroll((self.scroll_offset as u16, 0));
|
||||
|
||||
frame.render_widget(paragraph, overlay_area);
|
||||
|
||||
// Render scrollbar if content exceeds visible area
|
||||
let visible_height = height.saturating_sub(2) as usize; // -2 for borders
|
||||
let total_lines = self.commands.len() + 3;
|
||||
if total_lines > visible_height {
|
||||
let scrollbar = Scrollbar::default()
|
||||
.orientation(ScrollbarOrientation::VerticalRight)
|
||||
.begin_symbol(None)
|
||||
.end_symbol(None)
|
||||
.track_symbol(Some(" "))
|
||||
.thumb_symbol("│")
|
||||
.style(self.theme.status_dim);
|
||||
|
||||
let mut scrollbar_state = ScrollbarState::default()
|
||||
.content_length(total_lines)
|
||||
.position(self.scroll_offset);
|
||||
|
||||
// Adjust scrollbar area to be inside the border
|
||||
let scrollbar_area = Rect::new(
|
||||
overlay_area.x + overlay_area.width - 2,
|
||||
overlay_area.y + 1,
|
||||
1,
|
||||
overlay_area.height.saturating_sub(2),
|
||||
);
|
||||
|
||||
frame.render_stateful_widget(scrollbar, scrollbar_area, &mut scrollbar_state);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_command_help_visibility() {
|
||||
let theme = Theme::default();
|
||||
let mut help = CommandHelp::new(theme);
|
||||
|
||||
assert!(!help.is_visible());
|
||||
help.show();
|
||||
assert!(help.is_visible());
|
||||
help.hide();
|
||||
assert!(!help.is_visible());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_builtin_commands() {
|
||||
let commands = builtin_commands();
|
||||
assert!(!commands.is_empty());
|
||||
assert!(commands.iter().any(|c| c.name == "help"));
|
||||
assert!(commands.iter().any(|c| c.name == "provider"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scroll_navigation() {
|
||||
let theme = Theme::default();
|
||||
let mut help = CommandHelp::new(theme);
|
||||
help.show();
|
||||
|
||||
assert_eq!(help.scroll_offset, 0);
|
||||
help.scroll_down(3, 10);
|
||||
assert_eq!(help.scroll_offset, 3);
|
||||
help.scroll_up(1);
|
||||
assert_eq!(help.scroll_offset, 2);
|
||||
help.scroll_up(10); // Should clamp to 0
|
||||
assert_eq!(help.scroll_offset, 0);
|
||||
}
|
||||
}
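
Reviewer note: a usage sketch for the overlay, assuming the module above is in scope. While visible it consumes every key, so the caller only needs to check the returned boolean.

// Usage sketch for CommandHelp.
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};

let mut help = CommandHelp::new(Theme::default());
help.toggle(); // now visible, scroll reset to 0
let consumed = help.handle_key(KeyEvent::new(KeyCode::Char('j'), KeyModifiers::NONE));
assert!(consumed); // scrolled down one line; Esc, 'q', or '?' would close it instead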
507  crates/app/ui/src/components/input_box.rs  Normal file
@@ -0,0 +1,507 @@
|
||||
//! Vim-modal input component
|
||||
//!
|
||||
//! Borderless input with vim-like modes (Normal, Insert, Command).
|
||||
//! Uses mode prefix instead of borders for visual indication.
|
||||
|
||||
use crate::theme::{Theme, VimMode};
|
||||
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
|
||||
use ratatui::{
|
||||
layout::Rect,
|
||||
style::Style,
|
||||
text::{Line, Span},
|
||||
widgets::Paragraph,
|
||||
Frame,
|
||||
};
|
||||
|
||||
/// Input event from the input box
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum InputEvent {
|
||||
/// User submitted a message
|
||||
Message(String),
|
||||
/// User submitted a command (without / prefix)
|
||||
Command(String),
|
||||
/// Mode changed
|
||||
ModeChange(VimMode),
|
||||
/// Request to cancel current operation
|
||||
Cancel,
|
||||
/// Request to expand input (multiline)
|
||||
Expand,
|
||||
}
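
Reviewer note: this enum is the input box's only interface to the rest of the app. A caller (presumably app.rs, whose diff is suppressed above) would dispatch on the Option<InputEvent> returned by handle_key roughly as sketched here; `input_box` and `key` are assumed bindings, not code from this change.

// Sketch: dispatching InputEvent in the application loop.
match input_box.handle_key(key) {
    Some(InputEvent::Message(text)) => { /* send the message to the agent */ }
    Some(InputEvent::Command(cmd)) => { /* run the :command */ }
    Some(InputEvent::ModeChange(mode)) => { /* update the mode indicator in the status line */ }
    Some(InputEvent::Cancel) => { /* abort the in-flight request */ }
    Some(InputEvent::Expand) => { /* open the multiline editor */ }
    None => {}
}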
|
||||
|
||||
/// Vim-modal input box
|
||||
pub struct InputBox {
|
||||
input: String,
|
||||
cursor_position: usize,
|
||||
history: Vec<String>,
|
||||
history_index: usize,
|
||||
mode: VimMode,
|
||||
theme: Theme,
|
||||
}
|
||||
|
||||
impl InputBox {
|
||||
pub fn new(theme: Theme) -> Self {
|
||||
Self {
|
||||
input: String::new(),
|
||||
cursor_position: 0,
|
||||
history: Vec::new(),
|
||||
history_index: 0,
|
||||
mode: VimMode::Insert, // Start in insert mode for familiarity
|
||||
theme,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current vim mode
|
||||
pub fn mode(&self) -> VimMode {
|
||||
self.mode
|
||||
}
|
||||
|
||||
/// Set vim mode
|
||||
pub fn set_mode(&mut self, mode: VimMode) {
|
||||
self.mode = mode;
|
||||
}
|
||||
|
||||
/// Handle key event, returns input event if action is needed
|
||||
pub fn handle_key(&mut self, key: KeyEvent) -> Option<InputEvent> {
|
||||
match self.mode {
|
||||
VimMode::Normal => self.handle_normal_mode(key),
|
||||
VimMode::Insert => self.handle_insert_mode(key),
|
||||
VimMode::Command => self.handle_command_mode(key),
|
||||
VimMode::Visual => self.handle_visual_mode(key),
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle keys in normal mode
|
||||
fn handle_normal_mode(&mut self, key: KeyEvent) -> Option<InputEvent> {
|
||||
match key.code {
|
||||
// Enter insert mode
|
||||
KeyCode::Char('i') => {
|
||||
self.mode = VimMode::Insert;
|
||||
Some(InputEvent::ModeChange(VimMode::Insert))
|
||||
}
|
||||
KeyCode::Char('a') => {
|
||||
self.mode = VimMode::Insert;
|
||||
if self.cursor_position < self.input.len() {
|
||||
self.cursor_position += 1;
|
||||
}
|
||||
Some(InputEvent::ModeChange(VimMode::Insert))
|
||||
}
|
||||
KeyCode::Char('I') => {
|
||||
self.mode = VimMode::Insert;
|
||||
self.cursor_position = 0;
|
||||
Some(InputEvent::ModeChange(VimMode::Insert))
|
||||
}
|
||||
KeyCode::Char('A') => {
|
||||
self.mode = VimMode::Insert;
|
||||
self.cursor_position = self.input.len();
|
||||
Some(InputEvent::ModeChange(VimMode::Insert))
|
||||
}
|
||||
// Enter command mode
|
||||
KeyCode::Char(':') => {
|
||||
self.mode = VimMode::Command;
|
||||
self.input.clear();
|
||||
self.cursor_position = 0;
|
||||
Some(InputEvent::ModeChange(VimMode::Command))
|
||||
}
|
||||
// Navigation
|
||||
KeyCode::Char('h') | KeyCode::Left => {
|
||||
self.cursor_position = self.cursor_position.saturating_sub(1);
|
||||
None
|
||||
}
|
||||
KeyCode::Char('l') | KeyCode::Right => {
|
||||
if self.cursor_position < self.input.len() {
|
||||
self.cursor_position += 1;
|
||||
}
|
||||
None
|
||||
}
|
||||
KeyCode::Char('0') | KeyCode::Home => {
|
||||
self.cursor_position = 0;
|
||||
None
|
||||
}
|
||||
KeyCode::Char('$') | KeyCode::End => {
|
||||
self.cursor_position = self.input.len();
|
||||
None
|
||||
}
|
||||
KeyCode::Char('w') => {
|
||||
// Jump to next word
|
||||
self.cursor_position = self.next_word_position();
|
||||
None
|
||||
}
|
||||
KeyCode::Char('b') => {
|
||||
// Jump to previous word
|
||||
self.cursor_position = self.prev_word_position();
|
||||
None
|
||||
}
|
||||
// Editing
|
||||
KeyCode::Char('x') => {
|
||||
if self.cursor_position < self.input.len() {
|
||||
self.input.remove(self.cursor_position);
|
||||
}
|
||||
None
|
||||
}
|
||||
KeyCode::Char('d') => {
|
||||
// Delete line (dd would require tracking, simplify to clear)
|
||||
self.input.clear();
|
||||
self.cursor_position = 0;
|
||||
None
|
||||
}
|
||||
// History
|
||||
KeyCode::Char('k') | KeyCode::Up => {
|
||||
self.history_prev();
|
||||
None
|
||||
}
|
||||
KeyCode::Char('j') | KeyCode::Down => {
|
||||
self.history_next();
|
||||
None
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle keys in insert mode
|
||||
fn handle_insert_mode(&mut self, key: KeyEvent) -> Option<InputEvent> {
|
||||
match key.code {
|
||||
KeyCode::Esc => {
|
||||
self.mode = VimMode::Normal;
|
||||
// Move cursor back when exiting insert mode (vim behavior)
|
||||
if self.cursor_position > 0 {
|
||||
self.cursor_position -= 1;
|
||||
}
|
||||
Some(InputEvent::ModeChange(VimMode::Normal))
|
||||
}
|
||||
KeyCode::Enter => {
|
||||
let message = self.input.clone();
|
||||
if !message.trim().is_empty() {
|
||||
self.history.push(message.clone());
|
||||
self.history_index = self.history.len();
|
||||
self.input.clear();
|
||||
self.cursor_position = 0;
|
||||
return Some(InputEvent::Message(message));
|
||||
}
|
||||
None
|
||||
}
|
||||
KeyCode::Char('e') if key.modifiers.contains(KeyModifiers::CONTROL) => {
|
||||
Some(InputEvent::Expand)
|
||||
}
|
||||
KeyCode::Char('c') if key.modifiers.contains(KeyModifiers::CONTROL) => {
|
||||
Some(InputEvent::Cancel)
|
||||
}
|
||||
KeyCode::Char(c) => {
|
||||
self.input.insert(self.cursor_position, c);
|
||||
self.cursor_position += 1;
|
||||
None
|
||||
}
|
||||
KeyCode::Backspace => {
|
||||
if self.cursor_position > 0 {
|
||||
self.input.remove(self.cursor_position - 1);
|
||||
self.cursor_position -= 1;
|
||||
}
|
||||
None
|
||||
}
|
||||
KeyCode::Delete => {
|
||||
if self.cursor_position < self.input.len() {
|
||||
self.input.remove(self.cursor_position);
|
||||
}
|
||||
None
|
||||
}
|
||||
KeyCode::Left => {
|
||||
self.cursor_position = self.cursor_position.saturating_sub(1);
|
||||
None
|
||||
}
|
||||
KeyCode::Right => {
|
||||
if self.cursor_position < self.input.len() {
|
||||
self.cursor_position += 1;
|
||||
}
|
||||
None
|
||||
}
|
||||
KeyCode::Home => {
|
||||
self.cursor_position = 0;
|
||||
None
|
||||
}
|
||||
KeyCode::End => {
|
||||
self.cursor_position = self.input.len();
|
||||
None
|
||||
}
|
||||
KeyCode::Up => {
|
||||
self.history_prev();
|
||||
None
|
||||
}
|
||||
KeyCode::Down => {
|
||||
self.history_next();
|
||||
None
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle keys in command mode
|
||||
fn handle_command_mode(&mut self, key: KeyEvent) -> Option<InputEvent> {
|
||||
match key.code {
|
||||
KeyCode::Esc => {
|
||||
self.mode = VimMode::Normal;
|
||||
self.input.clear();
|
||||
self.cursor_position = 0;
|
||||
Some(InputEvent::ModeChange(VimMode::Normal))
|
||||
}
|
||||
KeyCode::Enter => {
|
||||
let command = self.input.clone();
|
||||
self.mode = VimMode::Normal;
|
||||
self.input.clear();
|
||||
self.cursor_position = 0;
|
||||
if !command.trim().is_empty() {
|
||||
return Some(InputEvent::Command(command));
|
||||
}
|
||||
Some(InputEvent::ModeChange(VimMode::Normal))
|
||||
}
|
||||
KeyCode::Char(c) => {
|
||||
self.input.insert(self.cursor_position, c);
|
||||
self.cursor_position += 1;
|
||||
None
|
||||
}
|
||||
KeyCode::Backspace => {
|
||||
if self.cursor_position > 0 {
|
||||
self.input.remove(self.cursor_position - 1);
|
||||
self.cursor_position -= 1;
|
||||
} else {
|
||||
// Empty command, exit to normal mode
|
||||
self.mode = VimMode::Normal;
|
||||
return Some(InputEvent::ModeChange(VimMode::Normal));
|
||||
}
|
||||
None
|
||||
}
|
||||
KeyCode::Left => {
|
||||
self.cursor_position = self.cursor_position.saturating_sub(1);
|
||||
None
|
||||
}
|
||||
KeyCode::Right => {
|
||||
if self.cursor_position < self.input.len() {
|
||||
self.cursor_position += 1;
|
||||
}
|
||||
None
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle keys in visual mode (simplified)
|
||||
fn handle_visual_mode(&mut self, key: KeyEvent) -> Option<InputEvent> {
|
||||
match key.code {
|
||||
KeyCode::Esc => {
|
||||
self.mode = VimMode::Normal;
|
||||
Some(InputEvent::ModeChange(VimMode::Normal))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// History navigation - previous
|
||||
fn history_prev(&mut self) {
|
||||
if !self.history.is_empty() && self.history_index > 0 {
|
||||
self.history_index -= 1;
|
||||
self.input = self.history[self.history_index].clone();
|
||||
self.cursor_position = self.input.len();
|
||||
}
|
||||
}
|
||||
|
||||
/// History navigation - next
|
||||
fn history_next(&mut self) {
|
||||
if self.history_index < self.history.len().saturating_sub(1) {
|
||||
self.history_index += 1;
|
||||
self.input = self.history[self.history_index].clone();
|
||||
self.cursor_position = self.input.len();
|
||||
} else if self.history_index < self.history.len() {
|
||||
self.history_index = self.history.len();
|
||||
self.input.clear();
|
||||
self.cursor_position = 0;
|
||||
}
|
||||
}
|
||||
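// Note on the history helpers above: `history_index == history.len()` represents the
// fresh (empty) prompt, so navigating down past the newest entry clears the input.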
|
||||
/// Find next word position
|
||||
fn next_word_position(&self) -> usize {
|
||||
let bytes = self.input.as_bytes();
|
||||
let mut pos = self.cursor_position;
|
||||
|
||||
// Skip current word
|
||||
while pos < bytes.len() && !bytes[pos].is_ascii_whitespace() {
|
||||
pos += 1;
|
||||
}
|
||||
// Skip whitespace
|
||||
while pos < bytes.len() && bytes[pos].is_ascii_whitespace() {
|
||||
pos += 1;
|
||||
}
|
||||
pos
|
||||
}
|
||||
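// Worked example: with input "hello world" and the cursor at 0, next_word_position()
// skips "hello", then the space, and returns 6 (the index of 'w').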
|
||||
/// Find previous word position
|
||||
fn prev_word_position(&self) -> usize {
|
||||
let bytes = self.input.as_bytes();
|
||||
let mut pos = self.cursor_position.saturating_sub(1);
|
||||
|
||||
// Skip whitespace
|
||||
while pos > 0 && bytes[pos].is_ascii_whitespace() {
|
||||
pos -= 1;
|
||||
}
|
||||
// Skip to start of word
|
||||
while pos > 0 && !bytes[pos - 1].is_ascii_whitespace() {
|
||||
pos -= 1;
|
||||
}
|
||||
pos
|
||||
}
|
||||
|
||||
/// Render the borderless input (single line)
|
||||
pub fn render(&self, frame: &mut Frame, area: Rect) {
|
||||
let is_empty = self.input.is_empty();
|
||||
let symbols = &self.theme.symbols;
|
||||
|
||||
// Mode-specific prefix
|
||||
let prefix = match self.mode {
|
||||
VimMode::Normal => Span::styled(
|
||||
format!("{} ", symbols.mode_normal),
|
||||
self.theme.status_dim,
|
||||
),
|
||||
VimMode::Insert => Span::styled(
|
||||
format!("{} ", symbols.user_prefix),
|
||||
self.theme.input_prefix,
|
||||
),
|
||||
VimMode::Command => Span::styled(
|
||||
": ",
|
||||
self.theme.input_prefix,
|
||||
),
|
||||
VimMode::Visual => Span::styled(
|
||||
format!("{} ", symbols.mode_visual),
|
||||
self.theme.status_accent,
|
||||
),
|
||||
};
|
||||
|
||||
// Cursor position handling
|
||||
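// Caveat: the byte-index slicing below assumes single-byte (ASCII) input;
// multi-byte characters would need char-boundary handling.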
let (text_before, cursor_char, text_after) = if self.cursor_position < self.input.len() {
|
||||
let before = &self.input[..self.cursor_position];
|
||||
let cursor = &self.input[self.cursor_position..self.cursor_position + 1];
|
||||
let after = &self.input[self.cursor_position + 1..];
|
||||
(before, cursor, after)
|
||||
} else {
|
||||
(&self.input[..], " ", "")
|
||||
};
|
||||
|
||||
let line = if is_empty && self.mode == VimMode::Insert {
|
||||
Line::from(vec![
|
||||
Span::raw(" "),
|
||||
prefix,
|
||||
Span::styled("▊", self.theme.input_prefix),
|
||||
Span::styled(" Type message...", self.theme.input_placeholder),
|
||||
])
|
||||
} else if is_empty && self.mode == VimMode::Command {
|
||||
Line::from(vec![
|
||||
Span::raw(" "),
|
||||
prefix,
|
||||
Span::styled("▊", self.theme.input_prefix),
|
||||
])
|
||||
} else {
|
||||
// Build cursor span with appropriate styling
|
||||
let cursor_style = if self.mode == VimMode::Normal {
|
||||
Style::default()
|
||||
.bg(self.theme.palette.fg)
|
||||
.fg(self.theme.palette.bg)
|
||||
} else {
|
||||
self.theme.input_prefix
|
||||
};
|
||||
|
||||
let cursor_span = if self.mode == VimMode::Normal && !is_empty {
|
||||
Span::styled(cursor_char.to_string(), cursor_style)
|
||||
} else {
|
||||
Span::styled("▊", self.theme.input_prefix)
|
||||
};
|
||||
|
||||
Line::from(vec![
|
||||
Span::raw(" "),
|
||||
prefix,
|
||||
Span::styled(text_before.to_string(), self.theme.input_text),
|
||||
cursor_span,
|
||||
Span::styled(text_after.to_string(), self.theme.input_text),
|
||||
])
|
||||
};
|
||||
|
||||
let paragraph = Paragraph::new(line);
|
||||
frame.render_widget(paragraph, area);
|
||||
}
|
||||
|
||||
/// Clear input
|
||||
pub fn clear(&mut self) {
|
||||
self.input.clear();
|
||||
self.cursor_position = 0;
|
||||
}
|
||||
|
||||
/// Get current input text
|
||||
pub fn text(&self) -> &str {
|
||||
&self.input
|
||||
}
|
||||
|
||||
/// Set input text
|
||||
pub fn set_text(&mut self, text: String) {
|
||||
self.input = text;
|
||||
self.cursor_position = self.input.len();
|
||||
}
|
||||
|
||||
/// Update theme
|
||||
pub fn set_theme(&mut self, theme: Theme) {
|
||||
self.theme = theme;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_mode_transitions() {
|
||||
let theme = Theme::default();
|
||||
let mut input = InputBox::new(theme);
|
||||
|
||||
// Start in insert mode
|
||||
assert_eq!(input.mode(), VimMode::Insert);
|
||||
|
||||
// Escape to normal mode
|
||||
let event = input.handle_key(KeyEvent::from(KeyCode::Esc));
|
||||
assert!(matches!(event, Some(InputEvent::ModeChange(VimMode::Normal))));
|
||||
assert_eq!(input.mode(), VimMode::Normal);
|
||||
|
||||
// 'i' to insert mode
|
||||
let event = input.handle_key(KeyEvent::from(KeyCode::Char('i')));
|
||||
assert!(matches!(event, Some(InputEvent::ModeChange(VimMode::Insert))));
|
||||
assert_eq!(input.mode(), VimMode::Insert);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_insert_text() {
|
||||
let theme = Theme::default();
|
||||
let mut input = InputBox::new(theme);
|
||||
|
||||
input.handle_key(KeyEvent::from(KeyCode::Char('h')));
|
||||
input.handle_key(KeyEvent::from(KeyCode::Char('i')));
|
||||
|
||||
assert_eq!(input.text(), "hi");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_command_mode() {
|
||||
let theme = Theme::default();
|
||||
let mut input = InputBox::new(theme);
|
||||
|
||||
// Escape to normal, then : to command
|
||||
input.handle_key(KeyEvent::from(KeyCode::Esc));
|
||||
input.handle_key(KeyEvent::from(KeyCode::Char(':')));
|
||||
|
||||
assert_eq!(input.mode(), VimMode::Command);
|
||||
|
||||
// Type command
|
||||
input.handle_key(KeyEvent::from(KeyCode::Char('q')));
|
||||
input.handle_key(KeyEvent::from(KeyCode::Char('u')));
|
||||
input.handle_key(KeyEvent::from(KeyCode::Char('i')));
|
||||
input.handle_key(KeyEvent::from(KeyCode::Char('t')));
|
||||
|
||||
assert_eq!(input.text(), "quit");
|
||||
|
||||
// Submit command
|
||||
let event = input.handle_key(KeyEvent::from(KeyCode::Enter));
|
||||
assert!(matches!(event, Some(InputEvent::Command(cmd)) if cmd == "quit"));
|
||||
}
|
||||
}
|
||||
19
crates/app/ui/src/components/mod.rs
Normal file
@@ -0,0 +1,19 @@
//! TUI components for the borderless multi-provider design

mod autocomplete;
mod chat_panel;
mod command_help;
mod input_box;
mod permission_popup;
mod provider_tabs;
mod status_bar;
mod todo_panel;

pub use autocomplete::{Autocomplete, AutocompleteOption, AutocompleteResult};
pub use chat_panel::{ChatMessage, ChatPanel, DisplayMessage};
pub use command_help::{Command, CommandHelp};
pub use input_box::{InputBox, InputEvent};
pub use permission_popup::{PermissionOption, PermissionPopup};
pub use provider_tabs::ProviderTabs;
pub use status_bar::{AppState, StatusBar};
pub use todo_panel::TodoPanel;
196
crates/app/ui/src/components/permission_popup.rs
Normal file
@@ -0,0 +1,196 @@
|
||||
use crate::theme::Theme;
|
||||
use crossterm::event::{KeyCode, KeyEvent};
|
||||
use permissions::PermissionDecision;
|
||||
use ratatui::{
|
||||
layout::{Constraint, Direction, Layout, Rect},
|
||||
style::{Modifier, Style},
|
||||
text::{Line, Span},
|
||||
widgets::{Block, Borders, Clear, Paragraph},
|
||||
Frame,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum PermissionOption {
|
||||
AllowOnce,
|
||||
AlwaysAllow,
|
||||
Deny,
|
||||
Explain,
|
||||
}
|
||||
|
||||
pub struct PermissionPopup {
|
||||
tool: String,
|
||||
context: Option<String>,
|
||||
selected: usize,
|
||||
theme: Theme,
|
||||
}
|
||||
|
||||
impl PermissionPopup {
|
||||
pub fn new(tool: String, context: Option<String>, theme: Theme) -> Self {
|
||||
Self {
|
||||
tool,
|
||||
context,
|
||||
selected: 0,
|
||||
theme,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn handle_key(&mut self, key: KeyEvent) -> Option<PermissionOption> {
|
||||
match key.code {
|
||||
KeyCode::Char('a') => Some(PermissionOption::AllowOnce),
|
||||
KeyCode::Char('A') => Some(PermissionOption::AlwaysAllow),
|
||||
KeyCode::Char('d') => Some(PermissionOption::Deny),
|
||||
KeyCode::Char('?') => Some(PermissionOption::Explain),
|
||||
KeyCode::Up => {
|
||||
self.selected = self.selected.saturating_sub(1);
|
||||
None
|
||||
}
|
||||
KeyCode::Down => {
|
||||
if self.selected < 3 {
|
||||
self.selected += 1;
|
||||
}
|
||||
None
|
||||
}
|
||||
KeyCode::Enter => match self.selected {
|
||||
0 => Some(PermissionOption::AllowOnce),
|
||||
1 => Some(PermissionOption::AlwaysAllow),
|
||||
2 => Some(PermissionOption::Deny),
|
||||
3 => Some(PermissionOption::Explain),
|
||||
_ => None,
|
||||
},
|
||||
KeyCode::Esc => Some(PermissionOption::Deny),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn render(&self, frame: &mut Frame, area: Rect) {
|
||||
// Center the popup
|
||||
let popup_area = crate::layout::AppLayout::center_popup(area, 64, 14);
|
||||
|
||||
// Clear the area behind the popup
|
||||
frame.render_widget(Clear, popup_area);
|
||||
|
||||
// Render popup with styled border
|
||||
let block = Block::default()
|
||||
.borders(Borders::ALL)
|
||||
.border_style(self.theme.popup_border)
|
||||
.style(self.theme.popup_bg)
|
||||
.title(Line::from(vec![
|
||||
Span::raw(" "),
|
||||
Span::styled("🔒", self.theme.popup_title),
|
||||
Span::raw(" "),
|
||||
Span::styled("Permission Required", self.theme.popup_title),
|
||||
Span::raw(" "),
|
||||
]));
|
||||
|
||||
frame.render_widget(block, popup_area);
|
||||
|
||||
// Split popup into sections
|
||||
let inner = popup_area.inner(ratatui::layout::Margin {
|
||||
vertical: 1,
|
||||
horizontal: 2,
|
||||
});
|
||||
|
||||
let sections = Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([
|
||||
Constraint::Length(2), // Tool name with box
|
||||
Constraint::Length(3), // Context (if any)
|
||||
Constraint::Length(1), // Separator
|
||||
Constraint::Length(1), // Option 1
|
||||
Constraint::Length(1), // Option 2
|
||||
Constraint::Length(1), // Option 3
|
||||
Constraint::Length(1), // Option 4
|
||||
Constraint::Length(1), // Help text
|
||||
])
|
||||
.split(inner);
|
||||
|
||||
// Tool name with highlight
|
||||
let tool_line = Line::from(vec![
|
||||
Span::styled("⚡ Tool: ", Style::default().fg(self.theme.palette.warning)),
|
||||
Span::styled(&self.tool, self.theme.popup_title),
|
||||
]);
|
||||
frame.render_widget(Paragraph::new(tool_line), sections[0]);
|
||||
|
||||
// Context with wrapping
|
||||
if let Some(ctx) = &self.context {
|
||||
// Truncate long context on a char boundary (raw byte slicing could panic on multi-byte UTF-8)
let context_text = if ctx.chars().count() > 100 {
    let truncated: String = ctx.chars().take(100).collect();
    format!("{}...", truncated)
} else {
    ctx.clone()
};
|
||||
let context_lines = textwrap::wrap(&context_text, (sections[1].width - 2) as usize);
|
||||
let mut lines = vec![
|
||||
Line::from(vec![
|
||||
Span::styled("📝 Context: ", Style::default().fg(self.theme.palette.info)),
|
||||
])
|
||||
];
|
||||
for line in context_lines.iter().take(2) {
|
||||
lines.push(Line::from(vec![
|
||||
Span::raw(" "),
|
||||
Span::styled(line.to_string(), Style::default().fg(self.theme.palette.fg_dim)),
|
||||
]));
|
||||
}
|
||||
frame.render_widget(Paragraph::new(lines), sections[1]);
|
||||
}
|
||||
|
||||
// Separator
|
||||
let separator = Line::styled(
|
||||
"─".repeat(sections[2].width as usize),
|
||||
Style::default().fg(self.theme.palette.divider_fg),
|
||||
);
|
||||
frame.render_widget(Paragraph::new(separator), sections[2]);
|
||||
|
||||
// Options with icons and colors
|
||||
let options = [
|
||||
("✓", " [a] Allow once", self.theme.palette.success, 0),
|
||||
("✓✓", " [A] Always allow", self.theme.palette.primary, 1),
|
||||
("✗", " [d] Deny", self.theme.palette.error, 2),
|
||||
("?", " [?] Explain", self.theme.palette.info, 3),
|
||||
];
|
||||
|
||||
for (icon, text, color, idx) in options.iter() {
|
||||
let (style, prefix) = if self.selected == *idx {
|
||||
(
|
||||
self.theme.selected,
|
||||
"▶ "
|
||||
)
|
||||
} else {
|
||||
(
|
||||
Style::default().fg(*color),
|
||||
" "
|
||||
)
|
||||
};
|
||||
|
||||
let line = Line::from(vec![
|
||||
Span::styled(prefix, style),
|
||||
Span::styled(*icon, style),
|
||||
Span::styled(*text, style),
|
||||
]);
|
||||
frame.render_widget(Paragraph::new(line), sections[3 + idx]);
|
||||
}
|
||||
|
||||
// Help text at bottom
|
||||
let help_line = Line::from(vec![
|
||||
Span::styled(
|
||||
"↑↓ Navigate Enter to select Esc to deny",
|
||||
Style::default().fg(self.theme.palette.fg_dim).add_modifier(Modifier::ITALIC),
|
||||
),
|
||||
]);
|
||||
frame.render_widget(Paragraph::new(help_line), sections[7]);
|
||||
}
|
||||
}
|
||||
|
||||
impl PermissionOption {
    pub fn to_decision(&self) -> Option<PermissionDecision> {
        match self {
            PermissionOption::AllowOnce => Some(PermissionDecision::Allow),
            PermissionOption::AlwaysAllow => Some(PermissionDecision::Allow),
            PermissionOption::Deny => Some(PermissionDecision::Deny),
            PermissionOption::Explain => None, // Special handling needed
        }
    }

    pub fn should_persist(&self) -> bool {
        matches!(self, PermissionOption::AlwaysAllow)
    }
}
|
||||
189
crates/app/ui/src/components/provider_tabs.rs
Normal file
@@ -0,0 +1,189 @@
|
||||
//! Provider tabs component for multi-LLM support
|
||||
//!
|
||||
//! Displays horizontal tabs for switching between providers (Claude, Ollama, OpenAI)
|
||||
//! with icons and keybind hints.
|
||||
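//!
//! Illustrative usage (a sketch based on the API below):
//!
//! ```ignore
//! let mut tabs = ProviderTabs::new(Theme::default());
//! assert_eq!(tabs.active(), Provider::Ollama); // local-first default
//! tabs.next();                  // Ollama -> OpenAI
//! tabs.select_by_number(1);     // jump straight to Claude
//! ```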
|
||||
use crate::theme::{Provider, Theme};
|
||||
use ratatui::{
|
||||
layout::Rect,
|
||||
style::Style,
|
||||
text::{Line, Span},
|
||||
widgets::Paragraph,
|
||||
Frame,
|
||||
};
|
||||
|
||||
/// Provider tab state and rendering
|
||||
pub struct ProviderTabs {
|
||||
active: Provider,
|
||||
theme: Theme,
|
||||
}
|
||||
|
||||
impl ProviderTabs {
|
||||
/// Create new provider tabs with default provider
|
||||
pub fn new(theme: Theme) -> Self {
|
||||
Self {
|
||||
active: Provider::Ollama, // Default to Ollama (local)
|
||||
theme,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with specific active provider
|
||||
pub fn with_provider(provider: Provider, theme: Theme) -> Self {
|
||||
Self {
|
||||
active: provider,
|
||||
theme,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the currently active provider
|
||||
pub fn active(&self) -> Provider {
|
||||
self.active
|
||||
}
|
||||
|
||||
/// Set the active provider
|
||||
pub fn set_active(&mut self, provider: Provider) {
|
||||
self.active = provider;
|
||||
}
|
||||
|
||||
/// Cycle to the next provider
|
||||
pub fn next(&mut self) {
|
||||
self.active = match self.active {
|
||||
Provider::Claude => Provider::Ollama,
|
||||
Provider::Ollama => Provider::OpenAI,
|
||||
Provider::OpenAI => Provider::Claude,
|
||||
};
|
||||
}
|
||||
|
||||
/// Cycle to the previous provider
|
||||
pub fn previous(&mut self) {
|
||||
self.active = match self.active {
|
||||
Provider::Claude => Provider::OpenAI,
|
||||
Provider::Ollama => Provider::Claude,
|
||||
Provider::OpenAI => Provider::Ollama,
|
||||
};
|
||||
}
|
||||
|
||||
/// Select provider by number (1, 2, 3)
|
||||
pub fn select_by_number(&mut self, num: u8) {
|
||||
self.active = match num {
|
||||
1 => Provider::Claude,
|
||||
2 => Provider::Ollama,
|
||||
3 => Provider::OpenAI,
|
||||
_ => self.active,
|
||||
};
|
||||
}
|
||||
|
||||
/// Update the theme
|
||||
pub fn set_theme(&mut self, theme: Theme) {
|
||||
self.theme = theme;
|
||||
}
|
||||
|
||||
/// Render the provider tabs (borderless)
|
||||
pub fn render(&self, frame: &mut Frame, area: Rect) {
|
||||
let mut spans = Vec::new();
|
||||
|
||||
// Add spacing at start
|
||||
spans.push(Span::raw(" "));
|
||||
|
||||
for (i, provider) in Provider::all().iter().enumerate() {
|
||||
let is_active = *provider == self.active;
|
||||
let icon = self.theme.provider_icon(*provider);
|
||||
let name = provider.name();
|
||||
let number = (i + 1).to_string();
|
||||
|
||||
// Keybind hint
|
||||
spans.push(Span::styled(
|
||||
format!("[{}] ", number),
|
||||
self.theme.status_dim,
|
||||
));
|
||||
|
||||
// Icon and name
|
||||
let style = if is_active {
|
||||
Style::default()
|
||||
.fg(self.theme.provider_color(*provider))
|
||||
.add_modifier(ratatui::style::Modifier::BOLD)
|
||||
} else {
|
||||
self.theme.tab_inactive
|
||||
};
|
||||
|
||||
spans.push(Span::styled(format!("{} ", icon), style));
|
||||
spans.push(Span::styled(name.to_string(), style));
|
||||
|
||||
// Separator between tabs (not after last)
|
||||
if i < Provider::all().len() - 1 {
|
||||
spans.push(Span::styled(
|
||||
format!(" {} ", self.theme.symbols.vertical_separator),
|
||||
self.theme.status_dim,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// Tab cycling hint on the right
|
||||
spans.push(Span::raw(" "));
|
||||
spans.push(Span::styled("[Tab] cycle", self.theme.status_dim));
|
||||
|
||||
let line = Line::from(spans);
|
||||
let paragraph = Paragraph::new(line);
|
||||
frame.render_widget(paragraph, area);
|
||||
}
|
||||
|
||||
/// Render a compact version (just active provider)
|
||||
pub fn render_compact(&self, frame: &mut Frame, area: Rect) {
|
||||
let icon = self.theme.provider_icon(self.active);
|
||||
let name = self.active.name();
|
||||
|
||||
let line = Line::from(vec![
|
||||
Span::raw(" "),
|
||||
Span::styled(
|
||||
format!("{} {}", icon, name),
|
||||
Style::default()
|
||||
.fg(self.theme.provider_color(self.active))
|
||||
.add_modifier(ratatui::style::Modifier::BOLD),
|
||||
),
|
||||
]);
|
||||
|
||||
let paragraph = Paragraph::new(line);
|
||||
frame.render_widget(paragraph, area);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_provider_cycling() {
|
||||
let theme = Theme::default();
|
||||
let mut tabs = ProviderTabs::new(theme);
|
||||
|
||||
assert_eq!(tabs.active(), Provider::Ollama);
|
||||
|
||||
tabs.next();
|
||||
assert_eq!(tabs.active(), Provider::OpenAI);
|
||||
|
||||
tabs.next();
|
||||
assert_eq!(tabs.active(), Provider::Claude);
|
||||
|
||||
tabs.next();
|
||||
assert_eq!(tabs.active(), Provider::Ollama);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_select_by_number() {
|
||||
let theme = Theme::default();
|
||||
let mut tabs = ProviderTabs::new(theme);
|
||||
|
||||
tabs.select_by_number(1);
|
||||
assert_eq!(tabs.active(), Provider::Claude);
|
||||
|
||||
tabs.select_by_number(2);
|
||||
assert_eq!(tabs.active(), Provider::Ollama);
|
||||
|
||||
tabs.select_by_number(3);
|
||||
assert_eq!(tabs.active(), Provider::OpenAI);
|
||||
|
||||
// Invalid number should not change
|
||||
tabs.select_by_number(4);
|
||||
assert_eq!(tabs.active(), Provider::OpenAI);
|
||||
}
|
||||
}
|
||||
188
crates/app/ui/src/components/status_bar.rs
Normal file
@@ -0,0 +1,188 @@
|
||||
//! Minimal status bar component
|
||||
//!
|
||||
//! Clean, readable status bar with essential info only.
|
||||
//! Format: ` Mode │ N msgs │ ~Nk tok │ state`
|
||||
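//!
//! For example, an idle Code-mode session with 12 messages and ~4500 estimated
//! tokens renders roughly as ` Code │ 12 msgs │ ~4k tok │ idle` (illustrative values).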
|
||||
use crate::theme::{Provider, Theme, VimMode};
|
||||
use agent_core::SessionStats;
|
||||
use permissions::Mode;
|
||||
use ratatui::{
|
||||
layout::Rect,
|
||||
style::Style,
|
||||
text::{Line, Span},
|
||||
widgets::Paragraph,
|
||||
Frame,
|
||||
};
|
||||
|
||||
/// Application state for status display
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AppState {
|
||||
Idle,
|
||||
Streaming,
|
||||
WaitingPermission,
|
||||
Error,
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub fn label(&self) -> &'static str {
|
||||
match self {
|
||||
AppState::Idle => "idle",
|
||||
AppState::Streaming => "streaming...",
|
||||
AppState::WaitingPermission => "waiting",
|
||||
AppState::Error => "error",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct StatusBar {
|
||||
provider: Provider,
|
||||
model: String,
|
||||
mode: Mode,
|
||||
vim_mode: VimMode,
|
||||
stats: SessionStats,
|
||||
last_tool: Option<String>,
|
||||
state: AppState,
|
||||
estimated_cost: f64,
|
||||
planning_mode: bool,
|
||||
theme: Theme,
|
||||
}
|
||||
|
||||
impl StatusBar {
|
||||
pub fn new(model: String, mode: Mode, theme: Theme) -> Self {
|
||||
Self {
|
||||
provider: Provider::Ollama, // Default provider
|
||||
model,
|
||||
mode,
|
||||
vim_mode: VimMode::Insert,
|
||||
stats: SessionStats::new(),
|
||||
last_tool: None,
|
||||
state: AppState::Idle,
|
||||
estimated_cost: 0.0,
|
||||
planning_mode: false,
|
||||
theme,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the active provider
|
||||
pub fn set_provider(&mut self, provider: Provider) {
|
||||
self.provider = provider;
|
||||
}
|
||||
|
||||
/// Set the current model
|
||||
pub fn set_model(&mut self, model: String) {
|
||||
self.model = model;
|
||||
}
|
||||
|
||||
/// Update session stats
|
||||
pub fn update_stats(&mut self, stats: SessionStats) {
|
||||
self.stats = stats;
|
||||
}
|
||||
|
||||
/// Set the last used tool
|
||||
pub fn set_last_tool(&mut self, tool: String) {
|
||||
self.last_tool = Some(tool);
|
||||
}
|
||||
|
||||
/// Set application state
|
||||
pub fn set_state(&mut self, state: AppState) {
|
||||
self.state = state;
|
||||
}
|
||||
|
||||
/// Set vim mode for display
|
||||
pub fn set_vim_mode(&mut self, mode: VimMode) {
|
||||
self.vim_mode = mode;
|
||||
}
|
||||
|
||||
/// Add to estimated cost
|
||||
pub fn add_cost(&mut self, cost: f64) {
|
||||
self.estimated_cost += cost;
|
||||
}
|
||||
|
||||
/// Reset cost
|
||||
pub fn reset_cost(&mut self) {
|
||||
self.estimated_cost = 0.0;
|
||||
}
|
||||
|
||||
/// Update theme
|
||||
pub fn set_theme(&mut self, theme: Theme) {
|
||||
self.theme = theme;
|
||||
}
|
||||
|
||||
/// Set planning mode status
|
||||
pub fn set_planning_mode(&mut self, active: bool) {
|
||||
self.planning_mode = active;
|
||||
}
|
||||
|
||||
/// Render the minimal status bar
|
||||
///
|
||||
/// Format: ` Mode │ N msgs │ ~Nk tok │ state`
|
||||
pub fn render(&self, frame: &mut Frame, area: Rect) {
|
||||
let sep = self.theme.symbols.vertical_separator;
|
||||
let sep_style = Style::default().fg(self.theme.palette.border);
|
||||
|
||||
// Permission mode
|
||||
let mode_str = if self.planning_mode {
|
||||
"PLAN"
|
||||
} else {
|
||||
match self.mode {
|
||||
Mode::Plan => "Plan",
|
||||
Mode::AcceptEdits => "Edit",
|
||||
Mode::Code => "Code",
|
||||
}
|
||||
};
|
||||
|
||||
// Format token count
|
||||
let tokens_str = if self.stats.estimated_tokens >= 1000 {
|
||||
format!("~{}k tok", self.stats.estimated_tokens / 1000)
|
||||
} else {
|
||||
format!("~{} tok", self.stats.estimated_tokens)
|
||||
};
|
||||
|
||||
// State style - only highlight non-idle states
|
||||
let state_style = match self.state {
|
||||
AppState::Idle => self.theme.status_dim,
|
||||
AppState::Streaming => Style::default().fg(self.theme.palette.success),
|
||||
AppState::WaitingPermission => Style::default().fg(self.theme.palette.warning),
|
||||
AppState::Error => Style::default().fg(self.theme.palette.error),
|
||||
};
|
||||
|
||||
// Build minimal status line
|
||||
let spans = vec![
|
||||
Span::styled(" ", self.theme.status_dim),
|
||||
// Mode
|
||||
Span::styled(mode_str, self.theme.status_dim),
|
||||
Span::styled(format!(" {} ", sep), sep_style),
|
||||
// Message count
|
||||
Span::styled(format!("{} msgs", self.stats.total_messages), self.theme.status_dim),
|
||||
Span::styled(format!(" {} ", sep), sep_style),
|
||||
// Token count
|
||||
Span::styled(&tokens_str, self.theme.status_dim),
|
||||
Span::styled(format!(" {} ", sep), sep_style),
|
||||
// State
|
||||
Span::styled(self.state.label(), state_style),
|
||||
];
|
||||
|
||||
let line = Line::from(spans);
|
||||
let paragraph = Paragraph::new(line);
|
||||
frame.render_widget(paragraph, area);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_status_bar_creation() {
|
||||
let theme = Theme::default();
|
||||
let status_bar = StatusBar::new("gpt-4".to_string(), Mode::Plan, theme);
|
||||
assert_eq!(status_bar.model, "gpt-4");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_app_state_display() {
|
||||
assert_eq!(AppState::Idle.label(), "idle");
|
||||
assert_eq!(AppState::Streaming.label(), "streaming...");
|
||||
assert_eq!(AppState::Error.label(), "error");
|
||||
}
|
||||
}
|
||||
200
crates/app/ui/src/components/todo_panel.rs
Normal file
@@ -0,0 +1,200 @@
|
||||
//! Todo panel component for displaying task list
|
||||
//!
|
||||
//! Shows the current todo list with status indicators and progress.
|
||||
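//!
//! Illustrative usage (a sketch; the panel only tracks its collapsed flag, and
//! `render` takes the shared `TodoList` by reference):
//!
//! ```ignore
//! let mut panel = TodoPanel::new(Theme::default());
//! panel.toggle();                // collapse to a one-line summary
//! let rows = panel.min_height(); // 1 when collapsed, 5 when expanded
//! ```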
|
||||
use ratatui::{
|
||||
layout::Rect,
|
||||
style::{Color, Modifier, Style},
|
||||
text::{Line, Span},
|
||||
widgets::{Block, Borders, Paragraph},
|
||||
Frame,
|
||||
};
|
||||
use tools_todo::{Todo, TodoList, TodoStatus};
|
||||
|
||||
use crate::theme::Theme;
|
||||
|
||||
/// Todo panel component
|
||||
pub struct TodoPanel {
|
||||
theme: Theme,
|
||||
collapsed: bool,
|
||||
}
|
||||
|
||||
impl TodoPanel {
|
||||
pub fn new(theme: Theme) -> Self {
|
||||
Self {
|
||||
theme,
|
||||
collapsed: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Toggle collapsed state
|
||||
pub fn toggle(&mut self) {
|
||||
self.collapsed = !self.collapsed;
|
||||
}
|
||||
|
||||
/// Check if collapsed
|
||||
pub fn is_collapsed(&self) -> bool {
|
||||
self.collapsed
|
||||
}
|
||||
|
||||
/// Update theme
|
||||
pub fn set_theme(&mut self, theme: Theme) {
|
||||
self.theme = theme;
|
||||
}
|
||||
|
||||
/// Get the minimum height needed for the panel
|
||||
pub fn min_height(&self) -> u16 {
|
||||
if self.collapsed {
|
||||
1
|
||||
} else {
|
||||
5
|
||||
}
|
||||
}
|
||||
|
||||
/// Render the todo panel
|
||||
pub fn render(&self, frame: &mut Frame, area: Rect, todos: &TodoList) {
|
||||
if self.collapsed {
|
||||
self.render_collapsed(frame, area, todos);
|
||||
} else {
|
||||
self.render_expanded(frame, area, todos);
|
||||
}
|
||||
}
|
||||
|
||||
/// Render collapsed view (single line summary)
|
||||
fn render_collapsed(&self, frame: &mut Frame, area: Rect, todos: &TodoList) {
|
||||
let items = todos.read();
|
||||
let completed = items.iter().filter(|t| t.status == TodoStatus::Completed).count();
|
||||
let in_progress = items.iter().filter(|t| t.status == TodoStatus::InProgress).count();
|
||||
let pending = items.iter().filter(|t| t.status == TodoStatus::Pending).count();
|
||||
|
||||
let summary = if items.is_empty() {
|
||||
"No tasks".to_string()
|
||||
} else {
|
||||
format!(
|
||||
"{} {} / {} {} / {} {}",
|
||||
self.theme.symbols.check, completed,
|
||||
self.theme.symbols.streaming, in_progress,
|
||||
self.theme.symbols.bullet, pending
|
||||
)
|
||||
};
|
||||
|
||||
let line = Line::from(vec![
|
||||
Span::styled("Tasks: ", self.theme.status_bar),
|
||||
Span::styled(summary, self.theme.status_dim),
|
||||
Span::styled(" [t to expand]", self.theme.status_dim),
|
||||
]);
|
||||
|
||||
let paragraph = Paragraph::new(line);
|
||||
frame.render_widget(paragraph, area);
|
||||
}
|
||||
|
||||
/// Render expanded view with task list
|
||||
fn render_expanded(&self, frame: &mut Frame, area: Rect, todos: &TodoList) {
|
||||
let items = todos.read();
|
||||
|
||||
let mut lines: Vec<Line> = Vec::new();
|
||||
|
||||
// Header
|
||||
lines.push(Line::from(vec![
|
||||
Span::styled("Tasks", Style::default().add_modifier(Modifier::BOLD)),
|
||||
Span::styled(" [t to collapse]", self.theme.status_dim),
|
||||
]));
|
||||
|
||||
if items.is_empty() {
|
||||
lines.push(Line::from(Span::styled(
|
||||
" No active tasks",
|
||||
self.theme.status_dim,
|
||||
)));
|
||||
} else {
|
||||
// Show tasks (limit to available space)
|
||||
let max_items = (area.height as usize).saturating_sub(2);
|
||||
let display_items: Vec<&Todo> = items.iter().take(max_items).collect();
|
||||
|
||||
for item in display_items {
|
||||
let (icon, style) = match item.status {
|
||||
TodoStatus::Completed => (
|
||||
self.theme.symbols.check,
|
||||
Style::default().fg(Color::Green),
|
||||
),
|
||||
TodoStatus::InProgress => (
|
||||
self.theme.symbols.streaming,
|
||||
Style::default().fg(Color::Yellow),
|
||||
),
|
||||
TodoStatus::Pending => (
|
||||
self.theme.symbols.bullet,
|
||||
self.theme.status_dim,
|
||||
),
|
||||
};
|
||||
|
||||
// Use active form for in-progress, content for others
|
||||
let text = if item.status == TodoStatus::InProgress {
|
||||
&item.active_form
|
||||
} else {
|
||||
&item.content
|
||||
};
|
||||
|
||||
// Truncate if too long
|
||||
let max_width = area.width.saturating_sub(6) as usize;
|
||||
let display_text = if text.len() > max_width {
|
||||
format!("{}...", &text[..max_width.saturating_sub(3)])
|
||||
} else {
|
||||
text.clone()
|
||||
};
|
||||
|
||||
lines.push(Line::from(vec![
|
||||
Span::styled(format!(" {} ", icon), style),
|
||||
Span::styled(display_text, style),
|
||||
]));
|
||||
}
|
||||
|
||||
// Show overflow indicator if needed
|
||||
if items.len() > max_items {
|
||||
lines.push(Line::from(Span::styled(
|
||||
format!(" ... and {} more", items.len() - max_items),
|
||||
self.theme.status_dim,
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let block = Block::default()
|
||||
.borders(Borders::TOP)
|
||||
.border_style(self.theme.status_dim);
|
||||
|
||||
let paragraph = Paragraph::new(lines).block(block);
|
||||
frame.render_widget(paragraph, area);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_todo_panel_creation() {
|
||||
let theme = Theme::default();
|
||||
let panel = TodoPanel::new(theme);
|
||||
assert!(!panel.is_collapsed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_todo_panel_toggle() {
|
||||
let theme = Theme::default();
|
||||
let mut panel = TodoPanel::new(theme);
|
||||
|
||||
assert!(!panel.is_collapsed());
|
||||
panel.toggle();
|
||||
assert!(panel.is_collapsed());
|
||||
panel.toggle();
|
||||
assert!(!panel.is_collapsed());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_min_height() {
|
||||
let theme = Theme::default();
|
||||
let mut panel = TodoPanel::new(theme);
|
||||
|
||||
assert_eq!(panel.min_height(), 5);
|
||||
panel.toggle();
|
||||
assert_eq!(panel.min_height(), 1);
|
||||
}
|
||||
}
|
||||
53
crates/app/ui/src/events.rs
Normal file
@@ -0,0 +1,53 @@
use crossterm::event::{KeyCode, KeyEvent, KeyModifiers};
use serde_json::Value;

/// Application events that drive the TUI
#[derive(Debug, Clone)]
pub enum AppEvent {
    /// User input from keyboard
    Input(KeyEvent),
    /// User submitted a message
    UserMessage(String),
    /// LLM streaming started
    StreamStart,
    /// LLM response chunk (streaming)
    LlmChunk(String),
    /// LLM streaming completed
    StreamEnd { response: String },
    /// LLM streaming error
    StreamError(String),
    /// Tool call started
    ToolCall { name: String, args: Value },
    /// Tool execution result
    ToolResult { success: bool, output: String },
    /// Permission request from agent
    PermissionRequest {
        tool: String,
        context: Option<String>,
    },
    /// Session statistics updated
    StatusUpdate(agent_core::SessionStats),
    /// Terminal was resized
    Resize { width: u16, height: u16 },
    /// Mouse scroll up
    ScrollUp,
    /// Mouse scroll down
    ScrollDown,
    /// Toggle the todo panel
    ToggleTodo,
    /// Application should quit
    Quit,
}

/// Process keyboard input into app events
pub fn handle_key_event(key: KeyEvent) -> Option<AppEvent> {
    match key.code {
        KeyCode::Char('c') if key.modifiers.contains(KeyModifiers::CONTROL) => {
            Some(AppEvent::Quit)
        }
        KeyCode::Char('t') if key.modifiers.contains(KeyModifiers::CONTROL) => {
            Some(AppEvent::ToggleTodo)
        }
        _ => Some(AppEvent::Input(key)),
    }
}
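// Illustrative mapping produced by handle_key_event above:
//   Ctrl-C         -> Some(AppEvent::Quit)
//   Ctrl-T         -> Some(AppEvent::ToggleTodo)
//   any other key  -> Some(AppEvent::Input(key)), left for the mode-aware input handling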
532
crates/app/ui/src/formatting.rs
Normal file
@@ -0,0 +1,532 @@
|
||||
//! Output formatting with markdown parsing and syntax highlighting
|
||||
//!
|
||||
//! This module provides rich text rendering for the TUI, converting markdown
|
||||
//! content into styled ratatui spans with proper syntax highlighting for code blocks.
|
||||
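//!
//! Illustrative usage (a sketch based on the types defined below):
//!
//! ```ignore
//! let renderer = MarkdownRenderer::new();
//! let content = renderer.render("# Title\n\nSome `inline code`.");
//! assert!(!content.is_empty()); // styled `Line`s ready for a ratatui Paragraph
//! ```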
|
||||
use pulldown_cmark::{CodeBlockKind, Event, Parser, Tag, TagEnd};
|
||||
use ratatui::style::{Color, Modifier, Style};
|
||||
use ratatui::text::{Line, Span};
|
||||
use syntect::easy::HighlightLines;
|
||||
use syntect::highlighting::{Theme, ThemeSet};
|
||||
use syntect::parsing::SyntaxSet;
|
||||
use syntect::util::LinesWithEndings;
|
||||
|
||||
/// Highlighter for syntax highlighting code blocks
|
||||
pub struct SyntaxHighlighter {
|
||||
syntax_set: SyntaxSet,
|
||||
theme: Theme,
|
||||
}
|
||||
|
||||
impl SyntaxHighlighter {
|
||||
/// Create a new syntax highlighter with default theme
|
||||
pub fn new() -> Self {
|
||||
let syntax_set = SyntaxSet::load_defaults_newlines();
|
||||
let theme_set = ThemeSet::load_defaults();
|
||||
// Use a dark theme that works well in terminals
|
||||
let theme = theme_set.themes["base16-ocean.dark"].clone();
|
||||
|
||||
Self { syntax_set, theme }
|
||||
}
|
||||
|
||||
/// Create highlighter with a specific theme name
|
||||
pub fn with_theme(theme_name: &str) -> Self {
|
||||
let syntax_set = SyntaxSet::load_defaults_newlines();
|
||||
let theme_set = ThemeSet::load_defaults();
|
||||
let theme = theme_set
|
||||
.themes
|
||||
.get(theme_name)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| theme_set.themes["base16-ocean.dark"].clone());
|
||||
|
||||
Self { syntax_set, theme }
|
||||
}
|
||||
|
||||
/// Get available theme names
|
||||
pub fn available_themes() -> Vec<&'static str> {
|
||||
vec![
|
||||
"base16-ocean.dark",
|
||||
"base16-eighties.dark",
|
||||
"base16-mocha.dark",
|
||||
"base16-ocean.light",
|
||||
"InspiredGitHub",
|
||||
"Solarized (dark)",
|
||||
"Solarized (light)",
|
||||
]
|
||||
}
|
||||
|
||||
/// Highlight a code block and return styled lines
|
||||
pub fn highlight_code(&self, code: &str, language: &str) -> Vec<Line<'static>> {
|
||||
// Find syntax for the language
|
||||
let syntax = self
|
||||
.syntax_set
|
||||
.find_syntax_by_token(language)
|
||||
.or_else(|| self.syntax_set.find_syntax_by_extension(language))
|
||||
.unwrap_or_else(|| self.syntax_set.find_syntax_plain_text());
|
||||
|
||||
let mut highlighter = HighlightLines::new(syntax, &self.theme);
|
||||
let mut lines = Vec::new();
|
||||
|
||||
for line in LinesWithEndings::from(code) {
|
||||
let Ok(ranges) = highlighter.highlight_line(line, &self.syntax_set) else {
|
||||
// Fallback to plain text if highlighting fails
|
||||
lines.push(Line::from(Span::raw(line.trim_end().to_string())));
|
||||
continue;
|
||||
};
|
||||
|
||||
let spans: Vec<Span<'static>> = ranges
|
||||
.into_iter()
|
||||
.map(|(style, text)| {
|
||||
let fg = syntect_to_ratatui_color(style.foreground);
|
||||
let ratatui_style = Style::default().fg(fg);
|
||||
Span::styled(text.trim_end_matches('\n').to_string(), ratatui_style)
|
||||
})
|
||||
.collect();
|
||||
|
||||
lines.push(Line::from(spans));
|
||||
}
|
||||
|
||||
lines
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SyntaxHighlighter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert syntect color to ratatui color
|
||||
fn syntect_to_ratatui_color(color: syntect::highlighting::Color) -> Color {
|
||||
Color::Rgb(color.r, color.g, color.b)
|
||||
}
|
||||
|
||||
/// Parsed markdown content ready for rendering
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct FormattedContent {
|
||||
pub lines: Vec<Line<'static>>,
|
||||
}
|
||||
|
||||
impl FormattedContent {
|
||||
/// Create empty formatted content
|
||||
pub fn empty() -> Self {
|
||||
Self { lines: Vec::new() }
|
||||
}
|
||||
|
||||
/// Get the number of lines
|
||||
pub fn len(&self) -> usize {
|
||||
self.lines.len()
|
||||
}
|
||||
|
||||
/// Check if content is empty
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.lines.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// Markdown parser that converts markdown to styled ratatui lines
|
||||
pub struct MarkdownRenderer {
|
||||
highlighter: SyntaxHighlighter,
|
||||
}
|
||||
|
||||
impl MarkdownRenderer {
|
||||
/// Create a new markdown renderer
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
highlighter: SyntaxHighlighter::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create renderer with custom highlighter
|
||||
pub fn with_highlighter(highlighter: SyntaxHighlighter) -> Self {
|
||||
Self { highlighter }
|
||||
}
|
||||
|
||||
/// Render markdown text to formatted content
|
||||
pub fn render(&self, markdown: &str) -> FormattedContent {
|
||||
let parser = Parser::new(markdown);
|
||||
let mut lines: Vec<Line<'static>> = Vec::new();
|
||||
let mut current_line_spans: Vec<Span<'static>> = Vec::new();
|
||||
|
||||
// State tracking
|
||||
let mut in_code_block = false;
|
||||
let mut code_block_lang = String::new();
|
||||
let mut code_block_content = String::new();
|
||||
let mut current_style = Style::default();
|
||||
let mut list_depth: usize = 0;
|
||||
let mut ordered_list_index: Option<u64> = None;
|
||||
|
||||
for event in parser {
|
||||
match event {
|
||||
Event::Start(tag) => match tag {
|
||||
Tag::Heading { level, .. } => {
|
||||
// Flush current line
|
||||
if !current_line_spans.is_empty() {
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
}
|
||||
// Style for headings
|
||||
current_style = match level {
|
||||
pulldown_cmark::HeadingLevel::H1 => Style::default()
|
||||
.fg(Color::Cyan)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
pulldown_cmark::HeadingLevel::H2 => Style::default()
|
||||
.fg(Color::Blue)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
pulldown_cmark::HeadingLevel::H3 => Style::default()
|
||||
.fg(Color::Green)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
_ => Style::default().add_modifier(Modifier::BOLD),
|
||||
};
|
||||
// Add heading prefix
|
||||
let prefix = "#".repeat(level as usize);
|
||||
current_line_spans.push(Span::styled(
|
||||
format!("{} ", prefix),
|
||||
Style::default().fg(Color::DarkGray),
|
||||
));
|
||||
}
|
||||
Tag::Paragraph => {
|
||||
// Start a new paragraph
|
||||
if !current_line_spans.is_empty() {
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
}
|
||||
}
|
||||
Tag::CodeBlock(kind) => {
|
||||
in_code_block = true;
|
||||
code_block_content.clear();
|
||||
code_block_lang = match kind {
|
||||
CodeBlockKind::Fenced(lang) => lang.to_string(),
|
||||
CodeBlockKind::Indented => String::new(),
|
||||
};
|
||||
// Flush current line and add code block header
|
||||
if !current_line_spans.is_empty() {
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
}
|
||||
// Add code fence line
|
||||
let fence_line = if code_block_lang.is_empty() {
|
||||
"```".to_string()
|
||||
} else {
|
||||
format!("```{}", code_block_lang)
|
||||
};
|
||||
lines.push(Line::from(Span::styled(
|
||||
fence_line,
|
||||
Style::default().fg(Color::DarkGray),
|
||||
)));
|
||||
}
|
||||
Tag::List(start) => {
|
||||
list_depth += 1;
|
||||
ordered_list_index = start;
|
||||
}
|
||||
Tag::Item => {
|
||||
// Flush current line
|
||||
if !current_line_spans.is_empty() {
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
}
|
||||
// Add list marker
|
||||
let indent = " ".repeat(list_depth.saturating_sub(1));
|
||||
let marker = if let Some(idx) = ordered_list_index {
|
||||
ordered_list_index = Some(idx + 1);
|
||||
format!("{}{}. ", indent, idx)
|
||||
} else {
|
||||
format!("{}- ", indent)
|
||||
};
|
||||
current_line_spans.push(Span::styled(
|
||||
marker,
|
||||
Style::default().fg(Color::Yellow),
|
||||
));
|
||||
}
|
||||
Tag::Emphasis => {
|
||||
current_style = current_style.add_modifier(Modifier::ITALIC);
|
||||
}
|
||||
Tag::Strong => {
|
||||
current_style = current_style.add_modifier(Modifier::BOLD);
|
||||
}
|
||||
Tag::Strikethrough => {
|
||||
current_style = current_style.add_modifier(Modifier::CROSSED_OUT);
|
||||
}
|
||||
Tag::Link { dest_url, .. } => {
|
||||
current_style = Style::default()
|
||||
.fg(Color::Blue)
|
||||
.add_modifier(Modifier::UNDERLINED);
|
||||
// Store URL for later
|
||||
current_line_spans.push(Span::styled(
|
||||
"[",
|
||||
Style::default().fg(Color::DarkGray),
|
||||
));
|
||||
// URL will be shown after link text
|
||||
code_block_content = dest_url.to_string();
|
||||
}
|
||||
Tag::BlockQuote(_) => {
|
||||
if !current_line_spans.is_empty() {
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
}
|
||||
current_line_spans.push(Span::styled(
|
||||
"│ ",
|
||||
Style::default().fg(Color::DarkGray),
|
||||
));
|
||||
current_style = Style::default().fg(Color::Gray).add_modifier(Modifier::ITALIC);
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
Event::End(tag_end) => match tag_end {
|
||||
TagEnd::Heading(_) => {
|
||||
current_style = Style::default();
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
}
|
||||
TagEnd::Paragraph => {
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
lines.push(Line::from("")); // Empty line after paragraph
|
||||
}
|
||||
TagEnd::CodeBlock => {
|
||||
in_code_block = false;
|
||||
// Highlight and add code content
|
||||
let highlighted =
|
||||
self.highlighter.highlight_code(&code_block_content, &code_block_lang);
|
||||
lines.extend(highlighted);
|
||||
// Add closing fence
|
||||
lines.push(Line::from(Span::styled(
|
||||
"```",
|
||||
Style::default().fg(Color::DarkGray),
|
||||
)));
|
||||
code_block_content.clear();
|
||||
code_block_lang.clear();
|
||||
}
|
||||
TagEnd::List(_) => {
|
||||
list_depth = list_depth.saturating_sub(1);
|
||||
if list_depth == 0 {
|
||||
ordered_list_index = None;
|
||||
}
|
||||
}
|
||||
TagEnd::Item => {
|
||||
if !current_line_spans.is_empty() {
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
}
|
||||
}
|
||||
TagEnd::Emphasis | TagEnd::Strong | TagEnd::Strikethrough => {
|
||||
current_style = Style::default();
|
||||
}
|
||||
TagEnd::Link => {
|
||||
current_line_spans.push(Span::styled(
|
||||
"]",
|
||||
Style::default().fg(Color::DarkGray),
|
||||
));
|
||||
current_line_spans.push(Span::styled(
|
||||
format!("({})", code_block_content),
|
||||
Style::default().fg(Color::DarkGray),
|
||||
));
|
||||
code_block_content.clear();
|
||||
current_style = Style::default();
|
||||
}
|
||||
TagEnd::BlockQuote => {
|
||||
current_style = Style::default();
|
||||
if !current_line_spans.is_empty() {
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
Event::Text(text) => {
|
||||
if in_code_block {
|
||||
code_block_content.push_str(&text);
|
||||
} else {
|
||||
current_line_spans.push(Span::styled(text.to_string(), current_style));
|
||||
}
|
||||
}
|
||||
Event::Code(code) => {
|
||||
// Inline code
|
||||
current_line_spans.push(Span::styled(
|
||||
format!("`{}`", code),
|
||||
Style::default().fg(Color::Magenta),
|
||||
));
|
||||
}
|
||||
Event::SoftBreak => {
|
||||
current_line_spans.push(Span::raw(" "));
|
||||
}
|
||||
Event::HardBreak => {
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
}
|
||||
Event::Rule => {
|
||||
if !current_line_spans.is_empty() {
|
||||
lines.push(Line::from(std::mem::take(&mut current_line_spans)));
|
||||
}
|
||||
lines.push(Line::from(Span::styled(
|
||||
"─".repeat(40),
|
||||
Style::default().fg(Color::DarkGray),
|
||||
)));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Flush any remaining content
|
||||
if !current_line_spans.is_empty() {
|
||||
lines.push(Line::from(current_line_spans));
|
||||
}
|
||||
|
||||
FormattedContent { lines }
|
||||
}
|
||||
|
||||
/// Render plain text (no markdown parsing)
|
||||
pub fn render_plain(&self, text: &str) -> FormattedContent {
|
||||
let lines = text
|
||||
.lines()
|
||||
.map(|line| Line::from(Span::raw(line.to_string())))
|
||||
.collect();
|
||||
FormattedContent { lines }
|
||||
}
|
||||
|
||||
/// Render a diff with +/- highlighting
|
||||
pub fn render_diff(&self, diff: &str) -> FormattedContent {
|
||||
let lines = diff
|
||||
.lines()
|
||||
.map(|line| {
|
||||
let style = if line.starts_with('+') && !line.starts_with("+++") {
|
||||
Style::default().fg(Color::Green)
|
||||
} else if line.starts_with('-') && !line.starts_with("---") {
|
||||
Style::default().fg(Color::Red)
|
||||
} else if line.starts_with("@@") {
|
||||
Style::default().fg(Color::Cyan)
|
||||
} else if line.starts_with("diff ") || line.starts_with("index ") {
|
||||
Style::default().fg(Color::Yellow)
|
||||
} else {
|
||||
Style::default()
|
||||
};
|
||||
Line::from(Span::styled(line.to_string(), style))
|
||||
})
|
||||
.collect();
|
||||
FormattedContent { lines }
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for MarkdownRenderer {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Format a file path with syntax highlighting based on extension
|
||||
pub fn format_file_path(path: &str) -> Span<'static> {
|
||||
let color = if path.ends_with(".rs") {
|
||||
Color::Rgb(222, 165, 132) // Rust orange
|
||||
} else if path.ends_with(".toml") {
|
||||
Color::Rgb(156, 220, 254) // Light blue
|
||||
} else if path.ends_with(".md") {
|
||||
Color::Rgb(86, 156, 214) // Blue
|
||||
} else if path.ends_with(".json") {
|
||||
Color::Rgb(206, 145, 120) // Brown
|
||||
} else if path.ends_with(".ts") || path.ends_with(".tsx") {
|
||||
Color::Rgb(49, 120, 198) // TypeScript blue
|
||||
} else if path.ends_with(".js") || path.ends_with(".jsx") {
|
||||
Color::Rgb(241, 224, 90) // JavaScript yellow
|
||||
} else if path.ends_with(".py") {
|
||||
Color::Rgb(55, 118, 171) // Python blue
|
||||
} else if path.ends_with(".go") {
|
||||
Color::Rgb(0, 173, 216) // Go cyan
|
||||
} else if path.ends_with(".sh") || path.ends_with(".bash") {
|
||||
Color::Rgb(137, 224, 81) // Shell green
|
||||
} else {
|
||||
Color::White
|
||||
};
|
||||
|
||||
Span::styled(path.to_string(), Style::default().fg(color))
|
||||
}
|
||||
|
||||
/// Format a tool name with appropriate styling
|
||||
pub fn format_tool_name(name: &str) -> Span<'static> {
|
||||
let style = Style::default()
|
||||
.fg(Color::Yellow)
|
||||
.add_modifier(Modifier::BOLD);
|
||||
Span::styled(name.to_string(), style)
|
||||
}
|
||||
|
||||
/// Format an error message
|
||||
pub fn format_error(message: &str) -> Line<'static> {
|
||||
Line::from(vec![
|
||||
Span::styled("Error: ", Style::default().fg(Color::Red).add_modifier(Modifier::BOLD)),
|
||||
Span::styled(message.to_string(), Style::default().fg(Color::Red)),
|
||||
])
|
||||
}
|
||||
|
||||
/// Format a success message
|
||||
pub fn format_success(message: &str) -> Line<'static> {
|
||||
Line::from(vec![
|
||||
Span::styled("✓ ", Style::default().fg(Color::Green)),
|
||||
Span::styled(message.to_string(), Style::default().fg(Color::Green)),
|
||||
])
|
||||
}
|
||||
|
||||
/// Format a warning message
|
||||
pub fn format_warning(message: &str) -> Line<'static> {
|
||||
Line::from(vec![
|
||||
Span::styled("⚠ ", Style::default().fg(Color::Yellow)),
|
||||
Span::styled(message.to_string(), Style::default().fg(Color::Yellow)),
|
||||
])
|
||||
}
|
||||
|
||||
/// Format an info message
|
||||
pub fn format_info(message: &str) -> Line<'static> {
|
||||
Line::from(vec![
|
||||
Span::styled("ℹ ", Style::default().fg(Color::Blue)),
|
||||
Span::styled(message.to_string(), Style::default().fg(Color::Blue)),
|
||||
])
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_syntax_highlighter_creation() {
|
||||
let highlighter = SyntaxHighlighter::new();
|
||||
let lines = highlighter.highlight_code("fn main() {}", "rust");
|
||||
assert!(!lines.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_markdown_render_heading() {
|
||||
let renderer = MarkdownRenderer::new();
|
||||
let content = renderer.render("# Hello World");
|
||||
assert!(!content.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_markdown_render_code_block() {
|
||||
let renderer = MarkdownRenderer::new();
|
||||
let content = renderer.render("```rust\nfn main() {}\n```");
|
||||
assert!(content.len() >= 3); // Opening fence, code, closing fence
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_markdown_render_list() {
|
||||
let renderer = MarkdownRenderer::new();
|
||||
let content = renderer.render("- Item 1\n- Item 2\n- Item 3");
|
||||
assert!(content.len() >= 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_diff_rendering() {
|
||||
let renderer = MarkdownRenderer::new();
|
||||
let diff = "+added line\n-removed line\n unchanged";
|
||||
let content = renderer.render_diff(diff);
|
||||
assert_eq!(content.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_file_path() {
|
||||
let span = format_file_path("src/main.rs");
|
||||
assert!(span.content.contains("main.rs"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_messages() {
|
||||
let error = format_error("Something went wrong");
|
||||
assert!(!error.spans.is_empty());
|
||||
|
||||
let success = format_success("Operation completed");
|
||||
assert!(!success.spans.is_empty());
|
||||
|
||||
let warning = format_warning("Be careful");
|
||||
assert!(!warning.spans.is_empty());
|
||||
|
||||
let info = format_info("FYI");
|
||||
assert!(!info.spans.is_empty());
|
||||
}
|
||||
}
|
||||
218
crates/app/ui/src/layout.rs
Normal file
@@ -0,0 +1,218 @@
|
||||
//! Layout calculation for the borderless TUI
|
||||
//!
|
||||
//! Uses vertical layout with whitespace for visual hierarchy instead of borders:
|
||||
//! - Header row (app name, mode, model, help)
|
||||
//! - Provider tabs
|
||||
//! - Horizontal divider
|
||||
//! - Chat area (scrollable)
|
||||
//! - Horizontal divider
|
||||
//! - Input area
|
||||
//! - Status bar
|
||||
|
||||
use ratatui::layout::{Constraint, Direction, Layout, Rect};
|
||||
|
||||
/// Calculated layout areas for the borderless TUI
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct AppLayout {
|
||||
/// Header row: app name, mode indicator, model, help hint
|
||||
pub header_area: Rect,
|
||||
/// Provider tabs row
|
||||
pub tabs_area: Rect,
|
||||
/// Top divider (horizontal rule)
|
||||
pub top_divider: Rect,
|
||||
/// Main chat/message area
|
||||
pub chat_area: Rect,
|
||||
/// Todo panel area (optional, between chat and input)
|
||||
pub todo_area: Rect,
|
||||
/// Bottom divider (horizontal rule)
|
||||
pub bottom_divider: Rect,
|
||||
/// Input area for user text
|
||||
pub input_area: Rect,
|
||||
/// Status bar at the bottom
|
||||
pub status_area: Rect,
|
||||
}
|
||||
|
||||
impl AppLayout {
|
||||
/// Calculate layout for the given terminal size
|
||||
pub fn calculate(area: Rect) -> Self {
|
||||
Self::calculate_with_todo(area, 0)
|
||||
}
|
||||
|
||||
/// Calculate layout with todo panel of specified height
|
||||
///
|
||||
/// Simplified layout without provider tabs:
|
||||
/// - Header (1 line)
|
||||
/// - Top divider (1 line)
|
||||
/// - Chat area (flexible)
|
||||
/// - Todo panel (optional)
|
||||
/// - Bottom divider (1 line)
|
||||
/// - Input (1 line)
|
||||
/// - Status bar (1 line)
|
||||
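///
/// For example (see the row list above), a 40-row terminal with `todo_height = 5`
/// leaves 40 - (1 + 1 + 5 + 1 + 1 + 1) = 30 rows for the chat area.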
pub fn calculate_with_todo(area: Rect, todo_height: u16) -> Self {
|
||||
let chunks = if todo_height > 0 {
|
||||
Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([
|
||||
Constraint::Length(1), // Header
|
||||
Constraint::Length(1), // Top divider
|
||||
Constraint::Min(5), // Chat area (flexible)
|
||||
Constraint::Length(todo_height), // Todo panel
|
||||
Constraint::Length(1), // Bottom divider
|
||||
Constraint::Length(1), // Input
|
||||
Constraint::Length(1), // Status bar
|
||||
])
|
||||
.split(area)
|
||||
} else {
|
||||
Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([
|
||||
Constraint::Length(1), // Header
|
||||
Constraint::Length(1), // Top divider
|
||||
Constraint::Min(5), // Chat area (flexible)
|
||||
Constraint::Length(0), // No todo panel
|
||||
Constraint::Length(1), // Bottom divider
|
||||
Constraint::Length(1), // Input
|
||||
Constraint::Length(1), // Status bar
|
||||
])
|
||||
.split(area)
|
||||
};
|
||||
|
||||
Self {
|
||||
header_area: chunks[0],
|
||||
tabs_area: Rect::default(), // Not used in simplified layout
|
||||
top_divider: chunks[1],
|
||||
chat_area: chunks[2],
|
||||
todo_area: chunks[3],
|
||||
bottom_divider: chunks[4],
|
||||
input_area: chunks[5],
|
||||
status_area: chunks[6],
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate layout with expanded input (multiline)
|
||||
pub fn calculate_expanded_input(area: Rect, input_lines: u16) -> Self {
|
||||
let input_height = input_lines.min(10).max(1); // Cap at 10 lines
|
||||
|
||||
let chunks = Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([
|
||||
Constraint::Length(1), // Header
|
||||
Constraint::Length(1), // Top divider
|
||||
Constraint::Min(5), // Chat area (flexible)
|
||||
Constraint::Length(0), // No todo panel
|
||||
Constraint::Length(1), // Bottom divider
|
||||
Constraint::Length(input_height), // Expanded input
|
||||
Constraint::Length(1), // Status bar
|
||||
])
|
||||
.split(area);
|
||||
|
||||
Self {
|
||||
header_area: chunks[0],
|
||||
tabs_area: Rect::default(),
|
||||
top_divider: chunks[1],
|
||||
chat_area: chunks[2],
|
||||
todo_area: chunks[3],
|
||||
bottom_divider: chunks[4],
|
||||
input_area: chunks[5],
|
||||
status_area: chunks[6],
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate layout without tabs (compact mode)
|
||||
pub fn calculate_compact(area: Rect) -> Self {
|
||||
let chunks = Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([
|
||||
Constraint::Length(1), // Header (includes compact provider indicator)
|
||||
Constraint::Length(1), // Top divider
|
||||
Constraint::Min(5), // Chat area (flexible)
|
||||
Constraint::Length(0), // No todo panel
|
||||
Constraint::Length(1), // Bottom divider
|
||||
Constraint::Length(1), // Input
|
||||
Constraint::Length(1), // Status bar
|
||||
])
|
||||
.split(area);
|
||||
|
||||
Self {
|
||||
header_area: chunks[0],
|
||||
tabs_area: Rect::default(), // No tabs area in compact mode
|
||||
top_divider: chunks[1],
|
||||
chat_area: chunks[2],
|
||||
todo_area: chunks[3],
|
||||
bottom_divider: chunks[4],
|
||||
input_area: chunks[5],
|
||||
status_area: chunks[6],
|
||||
}
|
||||
}
|
||||
|
||||
/// Center a popup in the given area
|
||||
pub fn center_popup(area: Rect, width: u16, height: u16) -> Rect {
|
||||
let popup_layout = Layout::default()
|
||||
.direction(Direction::Vertical)
|
||||
.constraints([
|
||||
Constraint::Length((area.height.saturating_sub(height)) / 2),
|
||||
Constraint::Length(height),
|
||||
Constraint::Length((area.height.saturating_sub(height)) / 2),
|
||||
])
|
||||
.split(area);
|
||||
|
||||
Layout::default()
|
||||
.direction(Direction::Horizontal)
|
||||
.constraints([
|
||||
Constraint::Length((area.width.saturating_sub(width)) / 2),
|
||||
Constraint::Length(width),
|
||||
Constraint::Length((area.width.saturating_sub(width)) / 2),
|
||||
])
|
||||
.split(popup_layout[1])[1]
|
||||
}
|
||||
}
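// Illustrative sketch (not part of this change): centering a fixed-size
// permission dialog over the chat area with `center_popup`. The 60x10 size is
// an arbitrary example value.
#[allow(dead_code)]
fn permission_dialog_area(layout: &AppLayout) -> Rect {
    AppLayout::center_popup(layout.chat_area, 60, 10)
}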
|
||||
|
||||
/// Layout mode based on terminal width
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum LayoutMode {
|
||||
/// Full layout with provider tabs (>= 80 cols)
|
||||
Full,
|
||||
/// Compact layout without tabs (< 80 cols)
|
||||
Compact,
|
||||
}
|
||||
|
||||
impl LayoutMode {
|
||||
/// Determine layout mode based on terminal width
|
||||
pub fn for_width(width: u16) -> Self {
|
||||
if width >= 80 {
|
||||
LayoutMode::Full
|
||||
} else {
|
||||
LayoutMode::Compact
|
||||
}
|
||||
}
|
||||
}
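// Illustrative sketch (not part of this change): picking a layout based on the
// terminal width reported by the frame. The decision to drop the todo panel in
// compact mode is an assumption for this example, not behavior defined above.
#[allow(dead_code)]
fn layout_for_frame(area: Rect, todo_height: u16) -> AppLayout {
    match LayoutMode::for_width(area.width) {
        LayoutMode::Full => AppLayout::calculate_with_todo(area, todo_height),
        LayoutMode::Compact => AppLayout::calculate_compact(area),
    }
}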
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_layout_calculation() {
|
||||
let area = Rect::new(0, 0, 120, 40);
|
||||
let layout = AppLayout::calculate(area);
|
||||
|
||||
// Header should be at top
|
||||
assert_eq!(layout.header_area.y, 0);
|
||||
assert_eq!(layout.header_area.height, 1);
|
||||
|
||||
// Status should be at bottom
|
||||
assert_eq!(layout.status_area.y, 39);
|
||||
assert_eq!(layout.status_area.height, 1);
|
||||
|
||||
// Chat area should have most of the space
|
||||
assert!(layout.chat_area.height > 20);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_layout_mode() {
|
||||
assert_eq!(LayoutMode::for_width(80), LayoutMode::Full);
|
||||
assert_eq!(LayoutMode::for_width(120), LayoutMode::Full);
|
||||
assert_eq!(LayoutMode::for_width(79), LayoutMode::Compact);
|
||||
assert_eq!(LayoutMode::for_width(60), LayoutMode::Compact);
|
||||
}
|
||||
}
|
||||
30
crates/app/ui/src/lib.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
pub mod app;
|
||||
pub mod completions;
|
||||
pub mod components;
|
||||
pub mod events;
|
||||
pub mod formatting;
|
||||
pub mod layout;
|
||||
pub mod output;
|
||||
pub mod theme;
|
||||
|
||||
pub use app::TuiApp;
|
||||
pub use completions::{CompletionEngine, Completion, CommandInfo};
|
||||
pub use events::AppEvent;
|
||||
pub use output::{CommandOutput, OutputFormat, TreeNode, ListItem};
|
||||
pub use formatting::{
|
||||
FormattedContent, MarkdownRenderer, SyntaxHighlighter,
|
||||
format_file_path, format_tool_name, format_error, format_success, format_warning, format_info,
|
||||
};
|
||||
|
||||
use color_eyre::eyre::Result;
|
||||
|
||||
/// Run the TUI application
|
||||
pub async fn run(
|
||||
client: llm_ollama::OllamaClient,
|
||||
opts: llm_core::ChatOptions,
|
||||
perms: permissions::PermissionManager,
|
||||
settings: config_agent::Settings,
|
||||
) -> Result<()> {
|
||||
let mut app = TuiApp::new(client, opts, perms, settings)?;
|
||||
app.run().await
|
||||
}
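// Illustrative sketch (not part of this change): how a binary crate might wire
// `run` together. The constructor calls below (`OllamaClient::new`,
// `ChatOptions::default`, `PermissionManager::default`, `Settings::default`)
// are assumptions about the other crates' APIs, so the example is left
// commented out.
//
// #[tokio::main]
// async fn main() -> Result<()> {
//     let client = llm_ollama::OllamaClient::new("http://localhost:11434");
//     let opts = llm_core::ChatOptions::default();
//     let perms = permissions::PermissionManager::default();
//     let settings = config_agent::Settings::default();
//     ui::run(client, opts, perms, settings).await
// }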
|
||||
388
crates/app/ui/src/output.rs
Normal file
@@ -0,0 +1,388 @@
|
||||
//! Rich command output formatting
|
||||
//!
|
||||
//! Provides formatted output for commands like /help, /mcp, /hooks
|
||||
//! with tables, trees, and syntax highlighting.
|
||||
|
||||
use ratatui::text::{Line, Span};
|
||||
use ratatui::style::{Color, Modifier, Style};
|
||||
|
||||
use crate::completions::CommandInfo;
|
||||
use crate::theme::Theme;
|
||||
|
||||
/// A tree node for hierarchical display
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct TreeNode {
|
||||
pub label: String,
|
||||
pub children: Vec<TreeNode>,
|
||||
}
|
||||
|
||||
impl TreeNode {
|
||||
pub fn new(label: impl Into<String>) -> Self {
|
||||
Self {
|
||||
label: label.into(),
|
||||
children: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_children(mut self, children: Vec<TreeNode>) -> Self {
|
||||
self.children = children;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// A list item with optional icon/marker
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ListItem {
|
||||
pub text: String,
|
||||
pub marker: Option<String>,
|
||||
pub style: Option<Style>,
|
||||
}
|
||||
|
||||
/// Different output formats
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum OutputFormat {
|
||||
/// Formatted table with headers and rows
|
||||
Table {
|
||||
headers: Vec<String>,
|
||||
rows: Vec<Vec<String>>,
|
||||
},
|
||||
/// Hierarchical tree view
|
||||
Tree {
|
||||
root: TreeNode,
|
||||
},
|
||||
/// Syntax-highlighted code block
|
||||
Code {
|
||||
language: String,
|
||||
content: String,
|
||||
},
|
||||
/// Side-by-side diff view
|
||||
Diff {
|
||||
old: String,
|
||||
new: String,
|
||||
},
|
||||
/// Simple list with markers
|
||||
List {
|
||||
items: Vec<ListItem>,
|
||||
},
|
||||
/// Plain text
|
||||
Text {
|
||||
content: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// Rich command output renderer
|
||||
pub struct CommandOutput {
|
||||
pub format: OutputFormat,
|
||||
}
|
||||
|
||||
impl CommandOutput {
|
||||
pub fn new(format: OutputFormat) -> Self {
|
||||
Self { format }
|
||||
}
|
||||
|
||||
/// Create a help table output
|
||||
pub fn help_table(commands: &[CommandInfo]) -> Self {
|
||||
let headers = vec![
|
||||
"Command".to_string(),
|
||||
"Description".to_string(),
|
||||
"Source".to_string(),
|
||||
];
|
||||
|
||||
let rows: Vec<Vec<String>> = commands
|
||||
.iter()
|
||||
.map(|c| vec![
|
||||
format!("/{}", c.name),
|
||||
c.description.clone(),
|
||||
c.source.clone(),
|
||||
])
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
format: OutputFormat::Table { headers, rows },
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an MCP servers tree view
|
||||
pub fn mcp_tree(servers: &[(String, Vec<String>)]) -> Self {
|
||||
let children: Vec<TreeNode> = servers
|
||||
.iter()
|
||||
.map(|(name, tools)| {
|
||||
TreeNode {
|
||||
label: name.clone(),
|
||||
children: tools.iter().map(|t| TreeNode::new(t)).collect(),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
format: OutputFormat::Tree {
|
||||
root: TreeNode {
|
||||
label: "MCP Servers".to_string(),
|
||||
children,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a hooks list output
|
||||
pub fn hooks_list(hooks: &[(String, String, bool)]) -> Self {
|
||||
let items: Vec<ListItem> = hooks
|
||||
.iter()
|
||||
.map(|(event, path, enabled)| {
|
||||
let marker = if *enabled { "✓" } else { "✗" };
|
||||
let style = if *enabled {
|
||||
Some(Style::default().fg(Color::Green))
|
||||
} else {
|
||||
Some(Style::default().fg(Color::Red))
|
||||
};
|
||||
ListItem {
|
||||
text: format!("{}: {}", event, path),
|
||||
marker: Some(marker.to_string()),
|
||||
style,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
format: OutputFormat::List { items },
|
||||
}
|
||||
}
|
||||
|
||||
/// Render to TUI Lines
|
||||
pub fn render(&self, theme: &Theme) -> Vec<Line<'static>> {
|
||||
match &self.format {
|
||||
OutputFormat::Table { headers, rows } => {
|
||||
self.render_table(headers, rows, theme)
|
||||
}
|
||||
OutputFormat::Tree { root } => {
|
||||
self.render_tree(root, 0, theme)
|
||||
}
|
||||
OutputFormat::List { items } => {
|
||||
self.render_list(items, theme)
|
||||
}
|
||||
OutputFormat::Code { content, .. } => {
|
||||
content.lines()
|
||||
.map(|line| Line::from(Span::styled(line.to_string(), theme.tool_call)))
|
||||
.collect()
|
||||
}
|
||||
OutputFormat::Diff { old, new } => {
|
||||
self.render_diff(old, new, theme)
|
||||
}
|
||||
OutputFormat::Text { content } => {
|
||||
content.lines()
|
||||
.map(|line| Line::from(line.to_string()))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn render_table(&self, headers: &[String], rows: &[Vec<String>], theme: &Theme) -> Vec<Line<'static>> {
|
||||
let mut lines = Vec::new();
|
||||
|
||||
// Calculate column widths
|
||||
let mut widths: Vec<usize> = headers.iter().map(|h| h.len()).collect();
|
||||
for row in rows {
|
||||
for (i, cell) in row.iter().enumerate() {
|
||||
if i < widths.len() {
|
||||
widths[i] = widths[i].max(cell.len());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Header line
|
||||
let header_spans: Vec<Span> = headers
|
||||
.iter()
|
||||
.enumerate()
|
||||
.flat_map(|(i, h)| {
|
||||
let padded = format!("{:width$}", h, width = widths.get(i).copied().unwrap_or(h.len()));
|
||||
vec![
|
||||
Span::styled(padded, Style::default().add_modifier(Modifier::BOLD)),
|
||||
Span::raw(" "),
|
||||
]
|
||||
})
|
||||
.collect();
|
||||
lines.push(Line::from(header_spans));
|
||||
|
||||
// Separator
|
||||
let sep: String = widths.iter().map(|w| "─".repeat(*w)).collect::<Vec<_>>().join("──");
|
||||
lines.push(Line::from(Span::styled(sep, theme.status_dim)));
|
||||
|
||||
// Rows
|
||||
for row in rows {
|
||||
let row_spans: Vec<Span> = row
|
||||
.iter()
|
||||
.enumerate()
|
||||
.flat_map(|(i, cell)| {
|
||||
let padded = format!("{:width$}", cell, width = widths.get(i).copied().unwrap_or(cell.len()));
|
||||
let style = if i == 0 {
|
||||
theme.status_accent // Command names in accent color
|
||||
} else {
|
||||
theme.status_bar
|
||||
};
|
||||
vec![
|
||||
Span::styled(padded, style),
|
||||
Span::raw(" "),
|
||||
]
|
||||
})
|
||||
.collect();
|
||||
lines.push(Line::from(row_spans));
|
||||
}
|
||||
|
||||
lines
|
||||
}
|
||||
|
||||
fn render_tree(&self, node: &TreeNode, depth: usize, theme: &Theme) -> Vec<Line<'static>> {
|
||||
let mut lines = Vec::new();
|
||||
|
||||
// Render current node
|
||||
let prefix = if depth == 0 {
|
||||
"".to_string()
|
||||
} else {
|
||||
format!("{}├─ ", "│ ".repeat(depth - 1))
|
||||
};
|
||||
|
||||
let style = if depth == 0 {
|
||||
Style::default().add_modifier(Modifier::BOLD)
|
||||
} else if node.children.is_empty() {
|
||||
theme.status_bar
|
||||
} else {
|
||||
theme.status_accent
|
||||
};
|
||||
|
||||
lines.push(Line::from(vec![
|
||||
Span::styled(prefix, theme.status_dim),
|
||||
Span::styled(node.label.clone(), style),
|
||||
]));
|
||||
|
||||
// Render children
|
||||
for child in &node.children {
|
||||
lines.extend(self.render_tree(child, depth + 1, theme));
|
||||
}
|
||||
|
||||
lines
|
||||
}
|
||||
|
||||
fn render_list(&self, items: &[ListItem], theme: &Theme) -> Vec<Line<'static>> {
|
||||
items
|
||||
.iter()
|
||||
.map(|item| {
|
||||
let marker_span = if let Some(marker) = &item.marker {
|
||||
Span::styled(
|
||||
format!("{} ", marker),
|
||||
item.style.unwrap_or(theme.status_bar),
|
||||
)
|
||||
} else {
|
||||
Span::raw("• ")
|
||||
};
|
||||
|
||||
Line::from(vec![
|
||||
marker_span,
|
||||
Span::styled(
|
||||
item.text.clone(),
|
||||
item.style.unwrap_or(theme.status_bar),
|
||||
),
|
||||
])
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn render_diff(&self, old: &str, new: &str, _theme: &Theme) -> Vec<Line<'static>> {
|
||||
let mut lines = Vec::new();
|
||||
|
||||
// Simple line-by-line diff
|
||||
let old_lines: Vec<&str> = old.lines().collect();
|
||||
let new_lines: Vec<&str> = new.lines().collect();
|
||||
|
||||
let max_len = old_lines.len().max(new_lines.len());
|
||||
|
||||
for i in 0..max_len {
|
||||
let old_line = old_lines.get(i).copied().unwrap_or("");
|
||||
let new_line = new_lines.get(i).copied().unwrap_or("");
|
||||
|
||||
if old_line != new_line {
|
||||
if !old_line.is_empty() {
|
||||
lines.push(Line::from(Span::styled(
|
||||
format!("- {}", old_line),
|
||||
Style::default().fg(Color::Red),
|
||||
)));
|
||||
}
|
||||
if !new_line.is_empty() {
|
||||
lines.push(Line::from(Span::styled(
|
||||
format!("+ {}", new_line),
|
||||
Style::default().fg(Color::Green),
|
||||
)));
|
||||
}
|
||||
} else {
|
||||
lines.push(Line::from(format!(" {}", old_line)));
|
||||
}
|
||||
}
|
||||
|
||||
lines
|
||||
}
|
||||
}
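// Illustrative sketch (not part of this change): building the /help table and
// rendering it into lines for the chat area. Only the concrete command list is
// invented for the example; the APIs are the ones defined in this crate.
#[allow(dead_code)]
fn render_help_example(theme: &Theme) -> Vec<Line<'static>> {
    let commands = vec![
        CommandInfo::new("help", "Show available commands", "builtin"),
        CommandInfo::new("mcp", "List MCP servers and their tools", "builtin"),
    ];
    CommandOutput::help_table(&commands).render(theme)
}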
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_help_table() {
|
||||
let commands = vec![
|
||||
CommandInfo::new("help", "Show help", "builtin"),
|
||||
CommandInfo::new("clear", "Clear screen", "builtin"),
|
||||
];
|
||||
let output = CommandOutput::help_table(&commands);
|
||||
|
||||
match output.format {
|
||||
OutputFormat::Table { headers, rows } => {
|
||||
assert_eq!(headers.len(), 3);
|
||||
assert_eq!(rows.len(), 2);
|
||||
}
|
||||
_ => panic!("Expected Table format"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mcp_tree() {
|
||||
let servers = vec![
|
||||
("filesystem".to_string(), vec!["read".to_string(), "write".to_string()]),
|
||||
("database".to_string(), vec!["query".to_string()]),
|
||||
];
|
||||
let output = CommandOutput::mcp_tree(&servers);
|
||||
|
||||
match output.format {
|
||||
OutputFormat::Tree { root } => {
|
||||
assert_eq!(root.label, "MCP Servers");
|
||||
assert_eq!(root.children.len(), 2);
|
||||
}
|
||||
_ => panic!("Expected Tree format"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hooks_list() {
|
||||
let hooks = vec![
|
||||
("PreToolUse".to_string(), "./hooks/pre".to_string(), true),
|
||||
("PostToolUse".to_string(), "./hooks/post".to_string(), false),
|
||||
];
|
||||
let output = CommandOutput::hooks_list(&hooks);
|
||||
|
||||
match output.format {
|
||||
OutputFormat::List { items } => {
|
||||
assert_eq!(items.len(), 2);
|
||||
}
|
||||
_ => panic!("Expected List format"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tree_node() {
|
||||
let node = TreeNode::new("root")
|
||||
.with_children(vec![
|
||||
TreeNode::new("child1"),
|
||||
TreeNode::new("child2"),
|
||||
]);
|
||||
assert_eq!(node.label, "root");
|
||||
assert_eq!(node.children.len(), 2);
|
||||
}
|
||||
}
|
||||
707
crates/app/ui/src/theme.rs
Normal file
@@ -0,0 +1,707 @@
|
||||
//! Theme system for the borderless TUI design
|
||||
//!
|
||||
//! Provides color palettes, semantic styling, and terminal capability detection
|
||||
//! for graceful degradation across different terminal emulators.
|
||||
|
||||
use ratatui::style::{Color, Modifier, Style};
|
||||
|
||||
/// Terminal capability detection for graceful degradation
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum TerminalCapability {
|
||||
/// Full Unicode support with true color
|
||||
Full,
|
||||
/// Basic Unicode with 256 colors
|
||||
Unicode256,
|
||||
/// ASCII only with 16 colors
|
||||
Basic,
|
||||
}
|
||||
|
||||
impl TerminalCapability {
|
||||
/// Detect terminal capabilities from environment
|
||||
pub fn detect() -> Self {
|
||||
// Check for true color support
|
||||
let colorterm = std::env::var("COLORTERM").unwrap_or_default();
|
||||
let term = std::env::var("TERM").unwrap_or_default();
|
||||
|
||||
if colorterm == "truecolor" || colorterm == "24bit" {
|
||||
return Self::Full;
|
||||
}
|
||||
|
||||
if term.contains("256color") || term.contains("kitty") || term.contains("alacritty") {
|
||||
return Self::Unicode256;
|
||||
}
|
||||
|
||||
// Check whether this is a Linux VT or another basic terminal
|
||||
if term == "linux" || term == "vt100" || term == "dumb" {
|
||||
return Self::Basic;
|
||||
}
|
||||
|
||||
// Default to unicode with 256 colors
|
||||
Self::Unicode256
|
||||
}
|
||||
|
||||
/// Check if Unicode box drawing is supported
|
||||
pub fn supports_unicode(&self) -> bool {
|
||||
matches!(self, Self::Full | Self::Unicode256)
|
||||
}
|
||||
|
||||
/// Check if true color (RGB) is supported
|
||||
pub fn supports_truecolor(&self) -> bool {
|
||||
matches!(self, Self::Full)
|
||||
}
|
||||
}
|
||||
|
||||
/// Symbols with fallbacks for different terminal capabilities
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Symbols {
|
||||
pub horizontal_rule: &'static str,
|
||||
pub vertical_separator: &'static str,
|
||||
pub bullet: &'static str,
|
||||
pub arrow: &'static str,
|
||||
pub check: &'static str,
|
||||
pub cross: &'static str,
|
||||
pub warning: &'static str,
|
||||
pub info: &'static str,
|
||||
pub streaming: &'static str,
|
||||
pub user_prefix: &'static str,
|
||||
pub assistant_prefix: &'static str,
|
||||
pub tool_prefix: &'static str,
|
||||
pub system_prefix: &'static str,
|
||||
// Provider icons
|
||||
pub claude_icon: &'static str,
|
||||
pub ollama_icon: &'static str,
|
||||
pub openai_icon: &'static str,
|
||||
// Vim mode indicators
|
||||
pub mode_normal: &'static str,
|
||||
pub mode_insert: &'static str,
|
||||
pub mode_visual: &'static str,
|
||||
pub mode_command: &'static str,
|
||||
}
|
||||
|
||||
impl Symbols {
|
||||
/// Unicode symbols for capable terminals
|
||||
pub fn unicode() -> Self {
|
||||
Self {
|
||||
horizontal_rule: "─",
|
||||
vertical_separator: "│",
|
||||
bullet: "•",
|
||||
arrow: "→",
|
||||
check: "✓",
|
||||
cross: "✗",
|
||||
warning: "⚠",
|
||||
info: "ℹ",
|
||||
streaming: "●",
|
||||
user_prefix: "❯",
|
||||
assistant_prefix: "◆",
|
||||
tool_prefix: "⚡",
|
||||
system_prefix: "○",
|
||||
claude_icon: "",
|
||||
ollama_icon: "",
|
||||
openai_icon: "",
|
||||
mode_normal: "[N]",
|
||||
mode_insert: "[I]",
|
||||
mode_visual: "[V]",
|
||||
mode_command: "[:]",
|
||||
}
|
||||
}
|
||||
|
||||
/// ASCII fallback symbols
|
||||
pub fn ascii() -> Self {
|
||||
Self {
|
||||
horizontal_rule: "-",
|
||||
vertical_separator: "|",
|
||||
bullet: "*",
|
||||
arrow: "->",
|
||||
check: "+",
|
||||
cross: "x",
|
||||
warning: "!",
|
||||
info: "i",
|
||||
streaming: "*",
|
||||
user_prefix: ">",
|
||||
assistant_prefix: "-",
|
||||
tool_prefix: "#",
|
||||
system_prefix: "-",
|
||||
claude_icon: "C",
|
||||
ollama_icon: "O",
|
||||
openai_icon: "G",
|
||||
mode_normal: "[N]",
|
||||
mode_insert: "[I]",
|
||||
mode_visual: "[V]",
|
||||
mode_command: "[:]",
|
||||
}
|
||||
}
|
||||
|
||||
/// Select symbols based on terminal capability
|
||||
pub fn for_capability(cap: TerminalCapability) -> Self {
|
||||
match cap {
|
||||
TerminalCapability::Full | TerminalCapability::Unicode256 => Self::unicode(),
|
||||
TerminalCapability::Basic => Self::ascii(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Modern color palette inspired by contemporary design systems
|
||||
///
|
||||
/// Color assignment principles:
|
||||
/// - fg (#c0caf5): PRIMARY text - user messages, command names
|
||||
/// - assistant (#9aa5ce): Soft gray-blue for AI responses (distinct from user)
|
||||
/// - accent (#7aa2f7): Interactive elements ONLY (mode, prompt symbol)
|
||||
/// - cmd_slash (#bb9af7): Purple for / prefix (signals "command")
|
||||
/// - fg_dim (#565f89): Timestamps, hints, inactive elements
|
||||
/// - selection (#283457): Highlighted row background
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ColorPalette {
|
||||
pub primary: Color,
|
||||
pub secondary: Color,
|
||||
pub accent: Color,
|
||||
pub success: Color,
|
||||
pub warning: Color,
|
||||
pub error: Color,
|
||||
pub info: Color,
|
||||
pub bg: Color,
|
||||
pub fg: Color,
|
||||
pub fg_dim: Color,
|
||||
pub fg_muted: Color,
|
||||
pub highlight: Color,
|
||||
pub border: Color, // For horizontal rules (subtle)
|
||||
pub selection: Color, // Highlighted row background
|
||||
// Provider-specific colors
|
||||
pub claude: Color,
|
||||
pub ollama: Color,
|
||||
pub openai: Color,
|
||||
// Semantic colors for messages
|
||||
pub user_fg: Color, // User message text (bright, fg)
|
||||
pub assistant_fg: Color, // Assistant message text (soft gray-blue)
|
||||
pub tool_fg: Color,
|
||||
pub timestamp_fg: Color,
|
||||
pub divider_fg: Color,
|
||||
// Command colors
|
||||
pub cmd_slash: Color, // Purple for / prefix
|
||||
pub cmd_name: Color, // Command name (same as fg)
|
||||
pub cmd_desc: Color, // Command description (dim)
|
||||
// Overlay/modal colors
|
||||
pub overlay_bg: Color, // Slightly lighter than main bg
|
||||
}
|
||||
|
||||
impl ColorPalette {
|
||||
/// Tokyo Night inspired palette - high contrast, readable
|
||||
///
|
||||
/// Key principles:
|
||||
/// - fg (#c0caf5) for user messages and command names
|
||||
/// - assistant (#a9b1d6) brighter gray-blue for AI responses (readable)
|
||||
/// - accent (#7aa2f7) only for interactive elements (mode indicator, prompt symbol)
|
||||
/// - cmd_slash (#bb9af7) purple for / prefix (signals "command")
|
||||
/// - fg_dim (#737aa2) for timestamps, hints, descriptions (brighter than before)
|
||||
/// - border (#3b4261) for horizontal rules
|
||||
pub fn tokyo_night() -> Self {
|
||||
Self {
|
||||
primary: Color::Rgb(122, 162, 247), // #7aa2f7 - Blue accent
|
||||
secondary: Color::Rgb(187, 154, 247), // #bb9af7 - Purple
|
||||
accent: Color::Rgb(122, 162, 247), // #7aa2f7 - Interactive elements ONLY
|
||||
success: Color::Rgb(158, 206, 106), // #9ece6a - Green
|
||||
warning: Color::Rgb(224, 175, 104), // #e0af68 - Yellow
|
||||
error: Color::Rgb(247, 118, 142), // #f7768e - Pink/Red
|
||||
info: Color::Rgb(125, 207, 255), // Cyan (rarely used)
|
||||
bg: Color::Rgb(26, 27, 38), // #1a1b26 - Dark bg
|
||||
fg: Color::Rgb(192, 202, 245), // #c0caf5 - Primary text (HIGH CONTRAST)
|
||||
fg_dim: Color::Rgb(115, 122, 162), // #737aa2 - Secondary text (BRIGHTER)
|
||||
fg_muted: Color::Rgb(86, 95, 137), // #565f89 - Very dim
|
||||
highlight: Color::Rgb(56, 62, 90), // Selection bg (legacy)
|
||||
border: Color::Rgb(73, 82, 115), // #495273 - Horizontal rules (BRIGHTER)
|
||||
selection: Color::Rgb(40, 52, 87), // #283457 - Highlighted row bg
|
||||
// Provider colors
|
||||
claude: Color::Rgb(217, 119, 87), // Claude orange
|
||||
ollama: Color::Rgb(122, 162, 247), // Blue
|
||||
openai: Color::Rgb(16, 163, 127), // OpenAI green
|
||||
// Message colors - user bright, assistant readable
|
||||
user_fg: Color::Rgb(192, 202, 245), // #c0caf5 - Same as fg (bright)
|
||||
assistant_fg: Color::Rgb(169, 177, 214), // #a9b1d6 - Brighter gray-blue (READABLE)
|
||||
tool_fg: Color::Rgb(224, 175, 104), // #e0af68 - Yellow for tools
|
||||
timestamp_fg: Color::Rgb(115, 122, 162), // #737aa2 - Brighter dim
|
||||
divider_fg: Color::Rgb(73, 82, 115), // #495273 - Border color (BRIGHTER)
|
||||
// Command colors
|
||||
cmd_slash: Color::Rgb(187, 154, 247), // #bb9af7 - Purple for / prefix
|
||||
cmd_name: Color::Rgb(192, 202, 245), // #c0caf5 - White for command name
|
||||
cmd_desc: Color::Rgb(115, 122, 162), // #737aa2 - Brighter description
|
||||
// Overlay colors
|
||||
overlay_bg: Color::Rgb(36, 40, 59), // #24283b - Slightly lighter than bg
|
||||
}
|
||||
}
|
||||
|
||||
/// Dracula inspired palette - classic and elegant
|
||||
pub fn dracula() -> Self {
|
||||
Self {
|
||||
primary: Color::Rgb(139, 233, 253), // Cyan
|
||||
secondary: Color::Rgb(189, 147, 249), // Purple
|
||||
accent: Color::Rgb(255, 121, 198), // Pink
|
||||
success: Color::Rgb(80, 250, 123), // Green
|
||||
warning: Color::Rgb(241, 250, 140), // Yellow
|
||||
error: Color::Rgb(255, 85, 85), // Red
|
||||
info: Color::Rgb(139, 233, 253), // Cyan
|
||||
bg: Color::Rgb(40, 42, 54), // Dark bg
|
||||
fg: Color::Rgb(248, 248, 242), // Light text
|
||||
fg_dim: Color::Rgb(98, 114, 164), // Comment
|
||||
fg_muted: Color::Rgb(68, 71, 90), // Very dim
|
||||
highlight: Color::Rgb(68, 71, 90), // Selection
|
||||
border: Color::Rgb(68, 71, 90),
|
||||
selection: Color::Rgb(68, 71, 90),
|
||||
claude: Color::Rgb(255, 121, 198),
|
||||
ollama: Color::Rgb(139, 233, 253),
|
||||
openai: Color::Rgb(80, 250, 123),
|
||||
user_fg: Color::Rgb(248, 248, 242),
|
||||
assistant_fg: Color::Rgb(189, 186, 220), // Softer purple-gray
|
||||
tool_fg: Color::Rgb(241, 250, 140),
|
||||
timestamp_fg: Color::Rgb(68, 71, 90),
|
||||
divider_fg: Color::Rgb(68, 71, 90),
|
||||
cmd_slash: Color::Rgb(189, 147, 249), // Purple
|
||||
cmd_name: Color::Rgb(248, 248, 242),
|
||||
cmd_desc: Color::Rgb(98, 114, 164),
|
||||
overlay_bg: Color::Rgb(50, 52, 64),
|
||||
}
|
||||
}
|
||||
|
||||
/// Catppuccin Mocha - warm and cozy
|
||||
pub fn catppuccin() -> Self {
|
||||
Self {
|
||||
primary: Color::Rgb(137, 180, 250), // Blue
|
||||
secondary: Color::Rgb(203, 166, 247), // Mauve
|
||||
accent: Color::Rgb(245, 194, 231), // Pink
|
||||
success: Color::Rgb(166, 227, 161), // Green
|
||||
warning: Color::Rgb(249, 226, 175), // Yellow
|
||||
error: Color::Rgb(243, 139, 168), // Red
|
||||
info: Color::Rgb(148, 226, 213), // Teal
|
||||
bg: Color::Rgb(30, 30, 46), // Base
|
||||
fg: Color::Rgb(205, 214, 244), // Text
|
||||
fg_dim: Color::Rgb(108, 112, 134), // Overlay
|
||||
fg_muted: Color::Rgb(69, 71, 90), // Surface
|
||||
highlight: Color::Rgb(49, 50, 68), // Surface
|
||||
border: Color::Rgb(69, 71, 90),
|
||||
selection: Color::Rgb(49, 50, 68),
|
||||
claude: Color::Rgb(245, 194, 231),
|
||||
ollama: Color::Rgb(137, 180, 250),
|
||||
openai: Color::Rgb(166, 227, 161),
|
||||
user_fg: Color::Rgb(205, 214, 244),
|
||||
assistant_fg: Color::Rgb(166, 187, 213), // Softer blue-gray
|
||||
tool_fg: Color::Rgb(249, 226, 175),
|
||||
timestamp_fg: Color::Rgb(69, 71, 90),
|
||||
divider_fg: Color::Rgb(69, 71, 90),
|
||||
cmd_slash: Color::Rgb(203, 166, 247), // Mauve
|
||||
cmd_name: Color::Rgb(205, 214, 244),
|
||||
cmd_desc: Color::Rgb(108, 112, 134),
|
||||
overlay_bg: Color::Rgb(40, 40, 56),
|
||||
}
|
||||
}
|
||||
|
||||
/// Nord - minimal and clean
|
||||
pub fn nord() -> Self {
|
||||
Self {
|
||||
primary: Color::Rgb(136, 192, 208), // Frost cyan
|
||||
secondary: Color::Rgb(129, 161, 193), // Frost blue
|
||||
accent: Color::Rgb(180, 142, 173), // Aurora purple
|
||||
success: Color::Rgb(163, 190, 140), // Aurora green
|
||||
warning: Color::Rgb(235, 203, 139), // Aurora yellow
|
||||
error: Color::Rgb(191, 97, 106), // Aurora red
|
||||
info: Color::Rgb(136, 192, 208), // Frost cyan
|
||||
bg: Color::Rgb(46, 52, 64), // Polar night
|
||||
fg: Color::Rgb(236, 239, 244), // Snow storm
|
||||
fg_dim: Color::Rgb(76, 86, 106), // Polar night light
|
||||
fg_muted: Color::Rgb(59, 66, 82),
|
||||
highlight: Color::Rgb(59, 66, 82), // Selection
|
||||
border: Color::Rgb(59, 66, 82),
|
||||
selection: Color::Rgb(59, 66, 82),
|
||||
claude: Color::Rgb(180, 142, 173),
|
||||
ollama: Color::Rgb(136, 192, 208),
|
||||
openai: Color::Rgb(163, 190, 140),
|
||||
user_fg: Color::Rgb(236, 239, 244),
|
||||
assistant_fg: Color::Rgb(180, 195, 210), // Softer blue-gray
|
||||
tool_fg: Color::Rgb(235, 203, 139),
|
||||
timestamp_fg: Color::Rgb(59, 66, 82),
|
||||
divider_fg: Color::Rgb(59, 66, 82),
|
||||
cmd_slash: Color::Rgb(180, 142, 173), // Aurora purple
|
||||
cmd_name: Color::Rgb(236, 239, 244),
|
||||
cmd_desc: Color::Rgb(76, 86, 106),
|
||||
overlay_bg: Color::Rgb(56, 62, 74),
|
||||
}
|
||||
}
|
||||
|
||||
/// Synthwave - vibrant and retro
|
||||
pub fn synthwave() -> Self {
|
||||
Self {
|
||||
primary: Color::Rgb(255, 0, 128), // Hot pink
|
||||
secondary: Color::Rgb(0, 229, 255), // Cyan
|
||||
accent: Color::Rgb(255, 128, 0), // Orange
|
||||
success: Color::Rgb(0, 255, 157), // Neon green
|
||||
warning: Color::Rgb(255, 215, 0), // Gold
|
||||
error: Color::Rgb(255, 64, 64), // Neon red
|
||||
info: Color::Rgb(0, 229, 255), // Cyan
|
||||
bg: Color::Rgb(20, 16, 32), // Dark purple
|
||||
fg: Color::Rgb(242, 233, 255), // Light purple
|
||||
fg_dim: Color::Rgb(127, 90, 180), // Mid purple
|
||||
fg_muted: Color::Rgb(72, 12, 168),
|
||||
highlight: Color::Rgb(72, 12, 168), // Deep purple
|
||||
border: Color::Rgb(72, 12, 168),
|
||||
selection: Color::Rgb(72, 12, 168),
|
||||
claude: Color::Rgb(255, 128, 0),
|
||||
ollama: Color::Rgb(0, 229, 255),
|
||||
openai: Color::Rgb(0, 255, 157),
|
||||
user_fg: Color::Rgb(242, 233, 255),
|
||||
assistant_fg: Color::Rgb(180, 170, 220), // Softer purple
|
||||
tool_fg: Color::Rgb(255, 215, 0),
|
||||
timestamp_fg: Color::Rgb(72, 12, 168),
|
||||
divider_fg: Color::Rgb(72, 12, 168),
|
||||
cmd_slash: Color::Rgb(255, 0, 128), // Hot pink
|
||||
cmd_name: Color::Rgb(242, 233, 255),
|
||||
cmd_desc: Color::Rgb(127, 90, 180),
|
||||
overlay_bg: Color::Rgb(30, 26, 42),
|
||||
}
|
||||
}
|
||||
|
||||
/// Rose Pine - elegant and muted
|
||||
pub fn rose_pine() -> Self {
|
||||
Self {
|
||||
primary: Color::Rgb(156, 207, 216), // Foam
|
||||
secondary: Color::Rgb(235, 188, 186), // Rose
|
||||
accent: Color::Rgb(234, 154, 151), // Love
|
||||
success: Color::Rgb(49, 116, 143), // Pine
|
||||
warning: Color::Rgb(246, 193, 119), // Gold
|
||||
error: Color::Rgb(235, 111, 146), // Love (darker)
|
||||
info: Color::Rgb(156, 207, 216), // Foam
|
||||
bg: Color::Rgb(25, 23, 36), // Base
|
||||
fg: Color::Rgb(224, 222, 244), // Text
|
||||
fg_dim: Color::Rgb(110, 106, 134), // Muted
|
||||
fg_muted: Color::Rgb(42, 39, 63),
|
||||
highlight: Color::Rgb(42, 39, 63), // Highlight
|
||||
border: Color::Rgb(42, 39, 63),
|
||||
selection: Color::Rgb(42, 39, 63),
|
||||
claude: Color::Rgb(234, 154, 151),
|
||||
ollama: Color::Rgb(156, 207, 216),
|
||||
openai: Color::Rgb(49, 116, 143),
|
||||
user_fg: Color::Rgb(224, 222, 244),
|
||||
assistant_fg: Color::Rgb(180, 185, 210), // Softer lavender-gray
|
||||
tool_fg: Color::Rgb(246, 193, 119),
|
||||
timestamp_fg: Color::Rgb(42, 39, 63),
|
||||
divider_fg: Color::Rgb(42, 39, 63),
|
||||
cmd_slash: Color::Rgb(235, 188, 186), // Rose
|
||||
cmd_name: Color::Rgb(224, 222, 244),
|
||||
cmd_desc: Color::Rgb(110, 106, 134),
|
||||
overlay_bg: Color::Rgb(35, 33, 46),
|
||||
}
|
||||
}
|
||||
|
||||
/// Midnight Ocean - deep and serene
|
||||
pub fn midnight_ocean() -> Self {
|
||||
Self {
|
||||
primary: Color::Rgb(102, 217, 239), // Bright cyan
|
||||
secondary: Color::Rgb(130, 170, 255), // Periwinkle
|
||||
accent: Color::Rgb(199, 146, 234), // Purple
|
||||
success: Color::Rgb(163, 190, 140), // Sea green
|
||||
warning: Color::Rgb(229, 200, 144), // Sandy yellow
|
||||
error: Color::Rgb(236, 95, 103), // Coral red
|
||||
info: Color::Rgb(102, 217, 239), // Bright cyan
|
||||
bg: Color::Rgb(1, 22, 39), // Deep ocean
|
||||
fg: Color::Rgb(201, 211, 235), // Light blue-white
|
||||
fg_dim: Color::Rgb(71, 103, 145), // Muted blue
|
||||
fg_muted: Color::Rgb(13, 43, 69),
|
||||
highlight: Color::Rgb(13, 43, 69), // Deep blue
|
||||
border: Color::Rgb(13, 43, 69),
|
||||
selection: Color::Rgb(13, 43, 69),
|
||||
claude: Color::Rgb(199, 146, 234),
|
||||
ollama: Color::Rgb(102, 217, 239),
|
||||
openai: Color::Rgb(163, 190, 140),
|
||||
user_fg: Color::Rgb(201, 211, 235),
|
||||
assistant_fg: Color::Rgb(150, 175, 200), // Softer blue-gray
|
||||
tool_fg: Color::Rgb(229, 200, 144),
|
||||
timestamp_fg: Color::Rgb(13, 43, 69),
|
||||
divider_fg: Color::Rgb(13, 43, 69),
|
||||
cmd_slash: Color::Rgb(199, 146, 234), // Purple
|
||||
cmd_name: Color::Rgb(201, 211, 235),
|
||||
cmd_desc: Color::Rgb(71, 103, 145),
|
||||
overlay_bg: Color::Rgb(11, 32, 49),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// LLM Provider enum
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Provider {
|
||||
Claude,
|
||||
Ollama,
|
||||
OpenAI,
|
||||
}
|
||||
|
||||
impl Provider {
|
||||
pub fn name(&self) -> &'static str {
|
||||
match self {
|
||||
Provider::Claude => "Claude",
|
||||
Provider::Ollama => "Ollama",
|
||||
Provider::OpenAI => "OpenAI",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn all() -> &'static [Provider] {
|
||||
&[Provider::Claude, Provider::Ollama, Provider::OpenAI]
|
||||
}
|
||||
}
|
||||
|
||||
/// Vim-like editing mode
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
|
||||
pub enum VimMode {
|
||||
#[default]
|
||||
Normal,
|
||||
Insert,
|
||||
Visual,
|
||||
Command,
|
||||
}
|
||||
|
||||
impl VimMode {
|
||||
pub fn indicator(&self, symbols: &Symbols) -> &'static str {
|
||||
match self {
|
||||
VimMode::Normal => symbols.mode_normal,
|
||||
VimMode::Insert => symbols.mode_insert,
|
||||
VimMode::Visual => symbols.mode_visual,
|
||||
VimMode::Command => symbols.mode_command,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Theme configuration for the borderless TUI
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Theme {
|
||||
pub palette: ColorPalette,
|
||||
pub symbols: Symbols,
|
||||
pub capability: TerminalCapability,
|
||||
// Message styles
|
||||
pub user_message: Style,
|
||||
pub assistant_message: Style,
|
||||
pub tool_call: Style,
|
||||
pub tool_result_success: Style,
|
||||
pub tool_result_error: Style,
|
||||
pub system_message: Style,
|
||||
pub timestamp: Style,
|
||||
// UI element styles
|
||||
pub divider: Style,
|
||||
pub header: Style,
|
||||
pub header_accent: Style,
|
||||
pub tab_active: Style,
|
||||
pub tab_inactive: Style,
|
||||
pub input_prefix: Style,
|
||||
pub input_text: Style,
|
||||
pub input_placeholder: Style,
|
||||
pub status_bar: Style,
|
||||
pub status_accent: Style,
|
||||
pub status_dim: Style,
|
||||
// Command styles
|
||||
pub cmd_slash: Style, // Purple for / prefix
|
||||
pub cmd_name: Style, // White for command name
|
||||
pub cmd_desc: Style, // Dim for description
|
||||
// Overlay/modal styles
|
||||
pub overlay_bg: Style, // Modal background
|
||||
pub selection_bg: Style, // Selected row background
|
||||
// Popup styles (for permission dialogs)
|
||||
pub popup_border: Style,
|
||||
pub popup_bg: Style,
|
||||
pub popup_title: Style,
|
||||
pub selected: Style,
|
||||
// Legacy compatibility
|
||||
pub border: Style,
|
||||
pub border_active: Style,
|
||||
pub status_bar_highlight: Style,
|
||||
pub input_box: Style,
|
||||
pub input_box_active: Style,
|
||||
}
|
||||
|
||||
impl Theme {
|
||||
/// Create theme from color palette with automatic capability detection
|
||||
pub fn from_palette(palette: ColorPalette) -> Self {
|
||||
let capability = TerminalCapability::detect();
|
||||
Self::from_palette_with_capability(palette, capability)
|
||||
}
|
||||
|
||||
/// Create theme with specific terminal capability
|
||||
pub fn from_palette_with_capability(palette: ColorPalette, capability: TerminalCapability) -> Self {
|
||||
let symbols = Symbols::for_capability(capability);
|
||||
|
||||
Self {
|
||||
// Message styles
|
||||
user_message: Style::default()
|
||||
.fg(palette.user_fg)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
assistant_message: Style::default().fg(palette.assistant_fg),
|
||||
tool_call: Style::default()
|
||||
.fg(palette.tool_fg)
|
||||
.add_modifier(Modifier::ITALIC),
|
||||
tool_result_success: Style::default()
|
||||
.fg(palette.success)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
tool_result_error: Style::default()
|
||||
.fg(palette.error)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
system_message: Style::default().fg(palette.fg_dim),
|
||||
timestamp: Style::default().fg(palette.timestamp_fg),
|
||||
// UI elements
|
||||
divider: Style::default().fg(palette.divider_fg),
|
||||
header: Style::default()
|
||||
.fg(palette.fg)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
header_accent: Style::default()
|
||||
.fg(palette.accent)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
tab_active: Style::default()
|
||||
.fg(palette.primary)
|
||||
.add_modifier(Modifier::BOLD | Modifier::UNDERLINED),
|
||||
tab_inactive: Style::default().fg(palette.fg_dim),
|
||||
input_prefix: Style::default()
|
||||
.fg(palette.accent)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
input_text: Style::default().fg(palette.fg),
|
||||
input_placeholder: Style::default().fg(palette.fg_muted),
|
||||
status_bar: Style::default().fg(palette.fg_dim),
|
||||
status_accent: Style::default().fg(palette.accent),
|
||||
status_dim: Style::default().fg(palette.fg_muted),
|
||||
// Command styles
|
||||
cmd_slash: Style::default().fg(palette.cmd_slash),
|
||||
cmd_name: Style::default().fg(palette.cmd_name),
|
||||
cmd_desc: Style::default().fg(palette.cmd_desc),
|
||||
// Overlay/modal styles
|
||||
overlay_bg: Style::default().bg(palette.overlay_bg),
|
||||
selection_bg: Style::default().bg(palette.selection),
|
||||
// Popup styles
|
||||
popup_border: Style::default()
|
||||
.fg(palette.border)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
popup_bg: Style::default().bg(palette.overlay_bg),
|
||||
popup_title: Style::default()
|
||||
.fg(palette.fg)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
selected: Style::default()
|
||||
.fg(palette.fg)
|
||||
.bg(palette.selection)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
// Legacy compatibility
|
||||
border: Style::default().fg(palette.fg_dim),
|
||||
border_active: Style::default()
|
||||
.fg(palette.primary)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
status_bar_highlight: Style::default()
|
||||
.fg(palette.bg)
|
||||
.bg(palette.accent)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
input_box: Style::default().fg(palette.fg),
|
||||
input_box_active: Style::default()
|
||||
.fg(palette.accent)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
symbols,
|
||||
capability,
|
||||
palette,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get provider-specific color
|
||||
pub fn provider_color(&self, provider: Provider) -> Color {
|
||||
match provider {
|
||||
Provider::Claude => self.palette.claude,
|
||||
Provider::Ollama => self.palette.ollama,
|
||||
Provider::OpenAI => self.palette.openai,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get provider icon
|
||||
pub fn provider_icon(&self, provider: Provider) -> &str {
|
||||
match provider {
|
||||
Provider::Claude => self.symbols.claude_icon,
|
||||
Provider::Ollama => self.symbols.ollama_icon,
|
||||
Provider::OpenAI => self.symbols.openai_icon,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a horizontal rule string of given width
|
||||
pub fn horizontal_rule(&self, width: usize) -> String {
|
||||
self.symbols.horizontal_rule.repeat(width)
|
||||
}
|
||||
|
||||
/// Tokyo Night theme (default) - modern and vibrant
|
||||
pub fn tokyo_night() -> Self {
|
||||
Self::from_palette(ColorPalette::tokyo_night())
|
||||
}
|
||||
|
||||
/// Dracula theme - classic dark theme
|
||||
pub fn dracula() -> Self {
|
||||
Self::from_palette(ColorPalette::dracula())
|
||||
}
|
||||
|
||||
/// Catppuccin Mocha - warm and cozy
|
||||
pub fn catppuccin() -> Self {
|
||||
Self::from_palette(ColorPalette::catppuccin())
|
||||
}
|
||||
|
||||
/// Nord theme - minimal and clean
|
||||
pub fn nord() -> Self {
|
||||
Self::from_palette(ColorPalette::nord())
|
||||
}
|
||||
|
||||
/// Synthwave theme - vibrant retro
|
||||
pub fn synthwave() -> Self {
|
||||
Self::from_palette(ColorPalette::synthwave())
|
||||
}
|
||||
|
||||
/// Rose Pine theme - elegant and muted
|
||||
pub fn rose_pine() -> Self {
|
||||
Self::from_palette(ColorPalette::rose_pine())
|
||||
}
|
||||
|
||||
/// Midnight Ocean theme - deep and serene
|
||||
pub fn midnight_ocean() -> Self {
|
||||
Self::from_palette(ColorPalette::midnight_ocean())
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Theme {
|
||||
fn default() -> Self {
|
||||
Self::tokyo_night()
|
||||
}
|
||||
}
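// Illustrative sketch (not part of this change): resolving a theme from a
// user-facing name, e.g. a `theme = "nord"` settings key. The accepted name
// strings are an assumption; only the constructors come from this file.
#[allow(dead_code)]
fn theme_by_name(name: &str) -> Theme {
    match name {
        "dracula" => Theme::dracula(),
        "catppuccin" => Theme::catppuccin(),
        "nord" => Theme::nord(),
        "synthwave" => Theme::synthwave(),
        "rose-pine" => Theme::rose_pine(),
        "midnight-ocean" => Theme::midnight_ocean(),
        _ => Theme::tokyo_night(),
    }
}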
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_terminal_capability_detection() {
|
||||
let cap = TerminalCapability::detect();
|
||||
// Should return some valid capability
|
||||
assert!(matches!(
|
||||
cap,
|
||||
TerminalCapability::Full | TerminalCapability::Unicode256 | TerminalCapability::Basic
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_symbols_for_capability() {
|
||||
let unicode = Symbols::for_capability(TerminalCapability::Full);
|
||||
assert_eq!(unicode.horizontal_rule, "─");
|
||||
|
||||
let ascii = Symbols::for_capability(TerminalCapability::Basic);
|
||||
assert_eq!(ascii.horizontal_rule, "-");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_theme_from_palette() {
|
||||
let theme = Theme::tokyo_night();
|
||||
// The theme's symbols should be consistent with the detected capability.
        assert_eq!(
            theme.symbols.horizontal_rule,
            Symbols::for_capability(theme.capability).horizontal_rule
        );
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_provider_colors() {
|
||||
let theme = Theme::tokyo_night();
|
||||
let claude_color = theme.provider_color(Provider::Claude);
|
||||
let ollama_color = theme.provider_color(Provider::Ollama);
|
||||
assert_ne!(claude_color, ollama_color);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_vim_mode_indicator() {
|
||||
let symbols = Symbols::unicode();
|
||||
assert_eq!(VimMode::Normal.indicator(&symbols), "[N]");
|
||||
assert_eq!(VimMode::Insert.indicator(&symbols), "[I]");
|
||||
}
|
||||
}
|
||||
29
crates/core/agent/Cargo.toml
Normal file
@@ -0,0 +1,29 @@
|
||||
[package]
|
||||
name = "agent-core"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
rust-version.workspace = true
|
||||
|
||||
[dependencies]
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
color-eyre = "0.6"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
futures-util = "0.3"
|
||||
tracing = "0.1"
|
||||
async-trait = "0.1"
|
||||
chrono = "0.4"
|
||||
|
||||
# Internal dependencies
|
||||
llm-core = { path = "../../llm/core" }
|
||||
permissions = { path = "../../platform/permissions" }
|
||||
tools-fs = { path = "../../tools/fs" }
|
||||
tools-bash = { path = "../../tools/bash" }
|
||||
tools-ask = { path = "../../tools/ask" }
|
||||
tools-todo = { path = "../../tools/todo" }
|
||||
tools-web = { path = "../../tools/web" }
|
||||
tools-plan = { path = "../../tools/plan" }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.13"
|
||||
74
crates/core/agent/examples/git_demo.rs
Normal file
@@ -0,0 +1,74 @@
|
||||
//! Example demonstrating the git integration module
|
||||
//!
|
||||
//! Run with: cargo run -p agent-core --example git_demo
|
||||
|
||||
use agent_core::{detect_git_state, format_git_status, is_safe_git_command, is_destructive_git_command};
|
||||
use std::env;
|
||||
|
||||
fn main() -> color_eyre::Result<()> {
|
||||
color_eyre::install()?;
|
||||
|
||||
// Get current working directory
|
||||
let cwd = env::current_dir()?;
|
||||
println!("Detecting git state in: {}\n", cwd.display());
|
||||
|
||||
// Detect git state
|
||||
let state = detect_git_state(&cwd)?;
|
||||
|
||||
// Display formatted status
|
||||
println!("{}\n", format_git_status(&state));
|
||||
|
||||
// Show detailed file status if there are changes
|
||||
if !state.status.is_empty() {
|
||||
println!("Detailed file status:");
|
||||
for status in &state.status {
|
||||
match status {
|
||||
agent_core::GitFileStatus::Modified { path } => {
|
||||
println!(" M {}", path);
|
||||
}
|
||||
agent_core::GitFileStatus::Added { path } => {
|
||||
println!(" A {}", path);
|
||||
}
|
||||
agent_core::GitFileStatus::Deleted { path } => {
|
||||
println!(" D {}", path);
|
||||
}
|
||||
agent_core::GitFileStatus::Renamed { from, to } => {
|
||||
println!(" R {} -> {}", from, to);
|
||||
}
|
||||
agent_core::GitFileStatus::Untracked { path } => {
|
||||
println!(" ? {}", path);
|
||||
}
|
||||
}
|
||||
}
|
||||
println!();
|
||||
}
|
||||
|
||||
// Test command safety checking
|
||||
println!("Command safety checks:");
|
||||
let test_commands = vec![
|
||||
"git status",
|
||||
"git log --oneline",
|
||||
"git diff HEAD",
|
||||
"git commit -m 'test'",
|
||||
"git push --force origin main",
|
||||
"git reset --hard HEAD~1",
|
||||
"git rebase main",
|
||||
"git branch -D feature",
|
||||
];
|
||||
|
||||
for cmd in test_commands {
|
||||
let is_safe = is_safe_git_command(cmd);
|
||||
let (is_destructive, warning) = is_destructive_git_command(cmd);
|
||||
|
||||
print!(" {} - ", cmd);
|
||||
if is_safe {
|
||||
println!("SAFE (read-only)");
|
||||
} else if is_destructive {
|
||||
println!("DESTRUCTIVE: {}", warning);
|
||||
} else {
|
||||
println!("UNSAFE (modifies state)");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
92
crates/core/agent/examples/streaming_agent.rs
Normal file
@@ -0,0 +1,92 @@
|
||||
//! Example demonstrating the streaming agent loop API
//!
//! This example shows how to use `run_agent_loop_streaming` to receive
//! real-time events during agent execution, including:
//! - Text deltas as the LLM generates text
//! - Tool execution start/end events
//! - Tool output events
//! - Final completion events
//!
//! Run with: cargo run --example streaming_agent -p agent-core
|
||||
|
||||
use agent_core::{create_event_channel, run_agent_loop_streaming, AgentEvent, ToolContext};
|
||||
use llm_core::ChatOptions;
|
||||
use permissions::{Mode, PermissionManager};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> color_eyre::Result<()> {
|
||||
color_eyre::install()?;
|
||||
|
||||
// Note: This is a minimal example. In a real application, you would:
|
||||
// 1. Initialize a real LLM provider (e.g., OllamaClient)
|
||||
// 2. Configure the ChatOptions with your preferred model
|
||||
// 3. Set up appropriate permissions and tool context
|
||||
|
||||
println!("=== Streaming Agent Example ===\n");
|
||||
println!("This example demonstrates how to use the streaming agent loop API.");
|
||||
println!("To run with a real LLM provider, modify this example to:");
|
||||
println!(" 1. Create an LLM provider instance");
|
||||
println!(" 2. Set up permissions and tool context");
|
||||
println!(" 3. Call run_agent_loop_streaming with your prompt\n");
|
||||
|
||||
// Example code structure:
|
||||
println!("Example code:");
|
||||
println!("```rust");
|
||||
println!("// Create LLM provider");
|
||||
println!("let provider = OllamaClient::new(\"http://localhost:11434\");");
|
||||
println!();
|
||||
println!("// Set up permissions and context");
|
||||
println!("let perms = PermissionManager::new(Mode::Plan);");
|
||||
println!("let ctx = ToolContext::default();");
|
||||
println!();
|
||||
println!("// Create event channel");
|
||||
println!("let (tx, mut rx) = create_event_channel();");
|
||||
println!();
|
||||
println!("// Spawn agent loop");
|
||||
println!("let handle = tokio::spawn(async move {{");
|
||||
println!(" run_agent_loop_streaming(");
|
||||
println!(" &provider,");
|
||||
println!(" \"Your prompt here\",");
|
||||
println!(" &ChatOptions::default(),");
|
||||
println!(" &perms,");
|
||||
println!(" &ctx,");
|
||||
println!(" tx,");
|
||||
println!(" ).await");
|
||||
println!("}});");
|
||||
println!();
|
||||
println!("// Process events");
|
||||
println!("while let Some(event) = rx.recv().await {{");
|
||||
println!(" match event {{");
|
||||
println!(" AgentEvent::TextDelta(text) => {{");
|
||||
println!(" print!(\"{{text}}\");");
|
||||
println!(" }}");
|
||||
println!(" AgentEvent::ToolStart {{ tool_name, .. }} => {{");
|
||||
println!(" println!(\"\\n[Executing tool: {{tool_name}}]\");");
|
||||
println!(" }}");
|
||||
println!(" AgentEvent::ToolOutput {{ content, is_error, .. }} => {{");
|
||||
println!(" if is_error {{");
|
||||
println!(" eprintln!(\"Error: {{content}}\");");
|
||||
println!(" }} else {{");
|
||||
println!(" println!(\"Output: {{content}}\");");
|
||||
println!(" }}");
|
||||
println!(" }}");
|
||||
println!(" AgentEvent::ToolEnd {{ success, .. }} => {{");
|
||||
println!(" println!(\"[Tool finished: {{}}]\", if success {{ \"success\" }} else {{ \"failed\" }});");
|
||||
println!(" }}");
|
||||
println!(" AgentEvent::Done {{ final_response }} => {{");
|
||||
println!(" println!(\"\\n\\nFinal response: {{final_response}}\");");
|
||||
println!(" break;");
|
||||
println!(" }}");
|
||||
println!(" AgentEvent::Error(e) => {{");
|
||||
println!(" eprintln!(\"Error: {{e}}\");");
|
||||
println!(" break;");
|
||||
println!(" }}");
|
||||
println!(" }}");
|
||||
println!("}}");
|
||||
println!();
|
||||
println!("// Wait for completion");
|
||||
println!("let result = handle.await??;");
|
||||
println!("```");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
218
crates/core/agent/src/compact.rs
Normal file
@@ -0,0 +1,218 @@
|
||||
//! Context compaction for long conversations
|
||||
//!
|
||||
//! When the conversation context grows too large, this module compacts
|
||||
//! earlier messages into a summary while preserving recent context.
|
||||
|
||||
use color_eyre::eyre::Result;
|
||||
use llm_core::{ChatMessage, ChatOptions, LlmProvider};
|
||||
|
||||
/// Token limit threshold for triggering compaction
|
||||
const CONTEXT_LIMIT: usize = 180_000;
|
||||
|
||||
/// Threshold ratio at which to trigger compaction (90% of limit)
|
||||
const COMPACTION_THRESHOLD: f64 = 0.9;
|
||||
|
||||
/// Number of recent messages to preserve during compaction
|
||||
const PRESERVE_RECENT: usize = 10;
|
||||
|
||||
/// Token counter for estimating context size
|
||||
pub struct TokenCounter {
|
||||
chars_per_token: f64,
|
||||
}
|
||||
|
||||
impl Default for TokenCounter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenCounter {
|
||||
pub fn new() -> Self {
|
||||
// Rough estimate: ~4 chars per token for English text
|
||||
Self { chars_per_token: 4.0 }
|
||||
}
|
||||
|
||||
/// Estimate token count for a message
|
||||
pub fn count_message(&self, message: &ChatMessage) -> usize {
|
||||
let content_len = message.content.as_ref().map(|c| c.len()).unwrap_or(0);
|
||||
// Add overhead for role, metadata
|
||||
let overhead = 10;
|
||||
((content_len as f64 / self.chars_per_token) as usize) + overhead
|
||||
}
|
||||
|
||||
/// Estimate total token count for all messages
|
||||
pub fn count_messages(&self, messages: &[ChatMessage]) -> usize {
|
||||
messages.iter().map(|m| self.count_message(m)).sum()
|
||||
}
|
||||
|
||||
/// Check if context should be compacted
|
||||
pub fn should_compact(&self, messages: &[ChatMessage]) -> bool {
|
||||
let count = self.count_messages(messages);
|
||||
count > (CONTEXT_LIMIT as f64 * COMPACTION_THRESHOLD) as usize
|
||||
}
|
||||
}
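// Illustrative sketch (not part of this change): the effective threshold from
// the constants above. With ~4 chars per token, compaction kicks in around
// 0.9 * 180_000 = 162_000 estimated tokens, i.e. roughly 648_000 characters.
#[allow(dead_code)]
fn compaction_threshold_tokens() -> usize {
    (CONTEXT_LIMIT as f64 * COMPACTION_THRESHOLD) as usize
}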
|
||||
|
||||
/// Context compactor that summarizes conversation history
|
||||
pub struct Compactor {
|
||||
token_counter: TokenCounter,
|
||||
}
|
||||
|
||||
impl Default for Compactor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Compactor {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
token_counter: TokenCounter::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if messages need compaction
|
||||
pub fn needs_compaction(&self, messages: &[ChatMessage]) -> bool {
|
||||
self.token_counter.should_compact(messages)
|
||||
}
|
||||
|
||||
/// Compact messages by summarizing earlier conversation
|
||||
///
|
||||
/// Returns compacted messages with:
|
||||
/// - A system message containing the summary of earlier context
|
||||
/// - The most recent N messages preserved in full
|
||||
pub async fn compact<P: LlmProvider>(
|
||||
&self,
|
||||
provider: &P,
|
||||
messages: &[ChatMessage],
|
||||
options: &ChatOptions,
|
||||
) -> Result<Vec<ChatMessage>> {
|
||||
// If not enough messages to compact, return as-is
|
||||
if messages.len() <= PRESERVE_RECENT + 1 {
|
||||
return Ok(messages.to_vec());
|
||||
}
|
||||
|
||||
// Split into messages to summarize and messages to preserve
|
||||
let split_point = messages.len().saturating_sub(PRESERVE_RECENT);
|
||||
let to_summarize = &messages[..split_point];
|
||||
let to_preserve = &messages[split_point..];
|
||||
|
||||
// Generate summary of earlier messages
|
||||
let summary = self.summarize_messages(provider, to_summarize, options).await?;
|
||||
|
||||
// Build compacted message list
|
||||
let mut compacted = Vec::with_capacity(PRESERVE_RECENT + 1);
|
||||
|
||||
// Add system message with summary
|
||||
compacted.push(ChatMessage::system(format!(
|
||||
"## Earlier Conversation Summary\n\n{}\n\n---\n\n\
|
||||
The above summarizes the earlier part of this conversation. \
|
||||
Continue from the recent messages below.",
|
||||
summary
|
||||
)));
|
||||
|
||||
// Add preserved recent messages
|
||||
compacted.extend(to_preserve.iter().cloned());
|
||||
|
||||
Ok(compacted)
|
||||
}
|
||||
|
||||
/// Generate a summary of messages using the LLM
|
||||
async fn summarize_messages<P: LlmProvider>(
|
||||
&self,
|
||||
provider: &P,
|
||||
messages: &[ChatMessage],
|
||||
options: &ChatOptions,
|
||||
) -> Result<String> {
|
||||
// Format messages for summarization
|
||||
let mut context = String::new();
|
||||
for msg in messages {
|
||||
let role = &msg.role;
|
||||
let content = msg.content.as_deref().unwrap_or("");
|
||||
context.push_str(&format!("[{:?}]: {}\n\n", role, content));
|
||||
}
|
||||
|
||||
// Create summarization prompt
|
||||
let summary_prompt = format!(
|
||||
"Please provide a concise summary of the following conversation. \
|
||||
Focus on:\n\
|
||||
1. Key decisions made\n\
|
||||
2. Important files or code mentioned\n\
|
||||
3. Tasks completed and their outcomes\n\
|
||||
4. Any pending items or next steps discussed\n\n\
|
||||
Keep the summary informative but brief (under 500 words).\n\n\
|
||||
Conversation:\n{}\n\n\
|
||||
Summary:",
|
||||
context
|
||||
);
|
||||
|
||||
// Call LLM to generate summary
|
||||
let summary_options = ChatOptions {
|
||||
model: options.model.clone(),
|
||||
max_tokens: Some(1000),
|
||||
temperature: Some(0.3), // Lower temperature for more focused summary
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let summary_messages = vec![ChatMessage::user(&summary_prompt)];
|
||||
let mut stream = provider.chat_stream(&summary_messages, &summary_options, None).await?;
|
||||
|
||||
let mut summary = String::new();
|
||||
use futures_util::StreamExt;
|
||||
while let Some(chunk_result) = stream.next().await {
|
||||
if let Ok(chunk) = chunk_result {
|
||||
if let Some(content) = &chunk.content {
|
||||
summary.push_str(content);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(summary.trim().to_string())
|
||||
}
|
||||
|
||||
/// Get token counter for external use
|
||||
pub fn token_counter(&self) -> &TokenCounter {
|
||||
&self.token_counter
|
||||
}
|
||||
}
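// Illustrative sketch (not part of this change): the call shape for wiring the
// compactor into an agent loop before each model turn. The provider stays
// generic, so no concrete LLM client constructor is assumed.
#[allow(dead_code)]
async fn maybe_compact<P: LlmProvider>(
    provider: &P,
    messages: Vec<ChatMessage>,
    options: &ChatOptions,
) -> Result<Vec<ChatMessage>> {
    let compactor = Compactor::new();
    if compactor.needs_compaction(&messages) {
        // Summarize older history, keeping the most recent messages verbatim.
        compactor.compact(provider, &messages, options).await
    } else {
        Ok(messages)
    }
}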
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_token_counter_estimate() {
|
||||
let counter = TokenCounter::new();
|
||||
let msg = ChatMessage::user("Hello, world!");
|
||||
let count = counter.count_message(&msg);
|
||||
// "Hello, world!" is 13 chars / 4 = ~3 tokens, plus 10 overhead, so roughly 13 total
|
||||
assert!(count > 10);
|
||||
assert!(count < 20);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_should_compact() {
|
||||
let counter = TokenCounter::new();
|
||||
|
||||
// Small message list shouldn't compact
|
||||
let small_messages: Vec<ChatMessage> = (0..10)
|
||||
.map(|i| ChatMessage::user(&format!("Message {}", i)))
|
||||
.collect();
|
||||
assert!(!counter.should_compact(&small_messages));
|
||||
|
||||
// Large message list should compact
|
||||
// Need ~162,000 tokens = ~648,000 chars (at 4 chars per token)
|
||||
let large_content = "x".repeat(700_000);
|
||||
let large_messages = vec![ChatMessage::user(&large_content)];
|
||||
assert!(counter.should_compact(&large_messages));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_compactor_needs_compaction() {
|
||||
let compactor = Compactor::new();
|
||||
|
||||
let small: Vec<ChatMessage> = (0..5)
|
||||
.map(|i| ChatMessage::user(&format!("Short message {}", i)))
|
||||
.collect();
|
||||
assert!(!compactor.needs_compaction(&small));
|
||||
}
|
||||
}
|
||||
557
crates/core/agent/src/git.rs
Normal file
@@ -0,0 +1,557 @@
|
||||
//! Git integration module for detecting repository state and validating git commands.
|
||||
//!
|
||||
//! This module provides functionality to:
|
||||
//! - Detect if the current directory is a git repository
|
||||
//! - Capture git repository state (branch, status, uncommitted changes)
|
||||
//! - Validate git commands for safety (read-only vs destructive operations)
|
||||
|
||||
use color_eyre::eyre::Result;
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
|
||||
/// Status of a file in the git working tree
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum GitFileStatus {
|
||||
/// File has been modified
|
||||
Modified { path: String },
|
||||
/// File has been added (staged)
|
||||
Added { path: String },
|
||||
/// File has been deleted
|
||||
Deleted { path: String },
|
||||
/// File has been renamed
|
||||
Renamed { from: String, to: String },
|
||||
/// File is untracked
|
||||
Untracked { path: String },
|
||||
}
|
||||
|
||||
impl GitFileStatus {
|
||||
/// Get the primary path associated with this status
|
||||
pub fn path(&self) -> &str {
|
||||
match self {
|
||||
Self::Modified { path } => path,
|
||||
Self::Added { path } => path,
|
||||
Self::Deleted { path } => path,
|
||||
Self::Renamed { to, .. } => to,
|
||||
Self::Untracked { path } => path,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Complete state of a git repository
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GitState {
|
||||
/// Whether the current directory is in a git repository
|
||||
pub is_git_repo: bool,
|
||||
/// Current branch name (None if not in a repo or detached HEAD)
|
||||
pub current_branch: Option<String>,
|
||||
/// Main branch name (main/master, None if not detected)
|
||||
pub main_branch: Option<String>,
|
||||
/// Status of files in the working tree
|
||||
pub status: Vec<GitFileStatus>,
|
||||
/// Whether there are any uncommitted changes
|
||||
pub has_uncommitted_changes: bool,
|
||||
/// Remote URL for the repository (None if no remote configured)
|
||||
pub remote_url: Option<String>,
|
||||
}
|
||||
|
||||
impl GitState {
|
||||
/// Create a default GitState for non-git directories
|
||||
pub fn not_a_repo() -> Self {
|
||||
Self {
|
||||
is_git_repo: false,
|
||||
current_branch: None,
|
||||
main_branch: None,
|
||||
status: Vec::new(),
|
||||
has_uncommitted_changes: false,
|
||||
remote_url: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect the current git repository state
|
||||
///
|
||||
/// This function runs various git commands to gather information about the repository.
|
||||
/// If git is not available or the directory is not a git repo, returns a default state.
|
||||
pub fn detect_git_state(working_dir: &Path) -> Result<GitState> {
|
||||
// Check if this is a git repository
|
||||
let is_repo = Command::new("git")
|
||||
.arg("rev-parse")
|
||||
.arg("--git-dir")
|
||||
.current_dir(working_dir)
|
||||
.output()
|
||||
.map(|output| output.status.success())
|
||||
.unwrap_or(false);
|
||||
|
||||
if !is_repo {
|
||||
return Ok(GitState::not_a_repo());
|
||||
}
|
||||
|
||||
// Get current branch
|
||||
let current_branch = get_current_branch(working_dir)?;
|
||||
|
||||
// Detect main branch (try main first, then master)
|
||||
let main_branch = detect_main_branch(working_dir)?;
|
||||
|
||||
// Get file status
|
||||
let status = get_git_status(working_dir)?;
|
||||
|
||||
// Check if there are uncommitted changes
|
||||
let has_uncommitted_changes = !status.is_empty();
|
||||
|
||||
// Get remote URL
|
||||
let remote_url = get_remote_url(working_dir)?;
|
||||
|
||||
Ok(GitState {
|
||||
is_git_repo: true,
|
||||
current_branch,
|
||||
main_branch,
|
||||
status,
|
||||
has_uncommitted_changes,
|
||||
remote_url,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the current branch name
|
||||
fn get_current_branch(working_dir: &Path) -> Result<Option<String>> {
|
||||
let output = Command::new("git")
|
||||
.arg("rev-parse")
|
||||
.arg("--abbrev-ref")
|
||||
.arg("HEAD")
|
||||
.current_dir(working_dir)
|
||||
.output()?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let branch = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
|
||||
// "HEAD" means detached HEAD state
|
||||
if branch == "HEAD" {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(branch))
|
||||
}
|
||||
}
|
||||
|
||||
/// Detect the main branch (main or master)
|
||||
fn detect_main_branch(working_dir: &Path) -> Result<Option<String>> {
|
||||
// Try to get all branches
|
||||
let output = Command::new("git")
|
||||
.arg("branch")
|
||||
.arg("-a")
|
||||
.current_dir(working_dir)
|
||||
.output()?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let branches = String::from_utf8_lossy(&output.stdout);
|
||||
|
||||
// Check for main branch first (modern convention)
|
||||
if branches.lines().any(|line| {
|
||||
let trimmed = line.trim_start_matches('*').trim();
|
||||
trimmed == "main" || trimmed.ends_with("/main")
|
||||
}) {
|
||||
return Ok(Some("main".to_string()));
|
||||
}
|
||||
|
||||
// Fall back to master
|
||||
if branches.lines().any(|line| {
|
||||
let trimmed = line.trim_start_matches('*').trim();
|
||||
trimmed == "master" || trimmed.ends_with("/master")
|
||||
}) {
|
||||
return Ok(Some("master".to_string()));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Get the git status for all files
|
||||
fn get_git_status(working_dir: &Path) -> Result<Vec<GitFileStatus>> {
|
||||
let output = Command::new("git")
|
||||
.arg("status")
|
||||
.arg("--porcelain")
|
||||
.arg("-z") // Null-terminated for better parsing
|
||||
.current_dir(working_dir)
|
||||
.output()?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let status_text = String::from_utf8_lossy(&output.stdout);
|
||||
let mut statuses = Vec::new();
|
||||
|
||||
// Parse porcelain format with null termination
|
||||
// Format: XY filename\0 (where X is staged status, Y is unstaged status)
|
||||
for entry in status_text.split('\0').filter(|s| !s.is_empty()) {
|
||||
if entry.len() < 3 {
|
||||
continue;
|
||||
}
|
||||
|
||||
let status_code = &entry[0..2];
|
||||
let path = entry[3..].to_string();
|
||||
|
||||
// Parse status codes
|
||||
match status_code {
|
||||
"M " | " M" | "MM" => {
|
||||
statuses.push(GitFileStatus::Modified { path });
|
||||
}
|
||||
"A " | " A" | "AM" => {
|
||||
statuses.push(GitFileStatus::Added { path });
|
||||
}
|
||||
"D " | " D" | "AD" => {
|
||||
statuses.push(GitFileStatus::Deleted { path });
|
||||
}
|
||||
"??" => {
|
||||
statuses.push(GitFileStatus::Untracked { path });
|
||||
}
|
||||
s if s.starts_with('R') => {
|
||||
// Without -z, renames are reported as "R old_name -> new_name"; with -z the old and new
// paths arrive as separate NUL-terminated fields, so this split may not match and we
// fall back to Modified below.
|
||||
if let Some((from, to)) = path.split_once(" -> ") {
|
||||
statuses.push(GitFileStatus::Renamed {
|
||||
from: from.to_string(),
|
||||
to: to.to_string(),
|
||||
});
|
||||
} else {
|
||||
// Fallback if parsing fails
|
||||
statuses.push(GitFileStatus::Modified { path });
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Unknown status code, treat as modified
|
||||
statuses.push(GitFileStatus::Modified { path });
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(statuses)
|
||||
}
|
||||
|
||||
/// Get the remote URL for the repository
|
||||
fn get_remote_url(working_dir: &Path) -> Result<Option<String>> {
|
||||
let output = Command::new("git")
|
||||
.arg("remote")
|
||||
.arg("get-url")
|
||||
.arg("origin")
|
||||
.current_dir(working_dir)
|
||||
.output()?;
|
||||
|
||||
if !output.status.success() {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let url = String::from_utf8_lossy(&output.stdout).trim().to_string();
|
||||
|
||||
if url.is_empty() {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(url))
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a git command is safe (read-only or otherwise low-risk)
///
/// Safe commands include:
/// - status, log, show, diff, blame, reflog, grep, shortlog, whatchanged
/// - branch (listing only), tag/describe (no deletion), remote (get-url/-v/show only)
/// - config --get / --list
/// - rev-parse, rev-list, ls-files, ls-tree, ls-remote
/// - fetch (updates remote-tracking refs but never touches the working tree)
|
||||
pub fn is_safe_git_command(command: &str) -> bool {
|
||||
let parts: Vec<&str> = command.split_whitespace().collect();
|
||||
|
||||
if parts.is_empty() || parts[0] != "git" {
|
||||
return false;
|
||||
}
|
||||
|
||||
if parts.len() < 2 {
|
||||
return false;
|
||||
}
|
||||
|
||||
let subcommand = parts[1];
|
||||
|
||||
// List of read-only git commands
|
||||
match subcommand {
|
||||
"status" | "log" | "show" | "diff" | "blame" | "reflog" => true,
|
||||
"ls-files" | "ls-tree" | "ls-remote" => true,
|
||||
"rev-parse" | "rev-list" => true,
|
||||
"describe" | "tag" if !command.contains("-d") && !command.contains("--delete") => true,
|
||||
"branch" if !command.contains("-D") && !command.contains("-d") && !command.contains("-m") => true,
|
||||
"remote" if command.contains("get-url") || command.contains("-v") || command.contains("show") => true,
|
||||
"config" if command.contains("--get") || command.contains("--list") => true,
|
||||
"grep" | "shortlog" | "whatchanged" => true,
|
||||
"fetch" if !command.contains("--prune") => true,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a git command is destructive
|
||||
///
|
||||
/// Returns (is_destructive, warning_message) tuple.
|
||||
/// Destructive commands include:
|
||||
/// - push --force, reset --hard, clean -fd
|
||||
/// - rebase, amend, filter-branch
|
||||
/// - branch -D, tag -d
|
||||
pub fn is_destructive_git_command(command: &str) -> (bool, &'static str) {
|
||||
let cmd_lower = command.to_lowercase();
|
||||
|
||||
// Check for force push
|
||||
if cmd_lower.contains("push") && (cmd_lower.contains("--force") || cmd_lower.contains("-f")) {
|
||||
return (true, "Force push can overwrite remote history and affect other collaborators");
|
||||
}
|
||||
|
||||
// Check for hard reset
|
||||
if cmd_lower.contains("reset") && cmd_lower.contains("--hard") {
|
||||
return (true, "Hard reset will discard uncommitted changes permanently");
|
||||
}
|
||||
|
||||
// Check for git clean
|
||||
if cmd_lower.contains("clean") && (cmd_lower.contains("-f") || cmd_lower.contains("-d")) {
|
||||
return (true, "Git clean will permanently delete untracked files");
|
||||
}
|
||||
|
||||
// Check for rebase
|
||||
if cmd_lower.contains("rebase") {
|
||||
return (true, "Rebase rewrites commit history and can cause conflicts");
|
||||
}
|
||||
|
||||
// Check for amend
|
||||
if cmd_lower.contains("commit") && cmd_lower.contains("--amend") {
|
||||
return (true, "Amending rewrites the last commit and changes its hash");
|
||||
}
|
||||
|
||||
// Check for filter-branch or filter-repo
|
||||
if cmd_lower.contains("filter-branch") || cmd_lower.contains("filter-repo") {
|
||||
return (true, "Filter operations rewrite repository history");
|
||||
}
|
||||
|
||||
// Check for branch/tag deletion
|
||||
if (cmd_lower.contains("branch") && (cmd_lower.contains("-D") || cmd_lower.contains("-d")))
|
||||
|| (cmd_lower.contains("tag") && (cmd_lower.contains("-d") || cmd_lower.contains("--delete")))
|
||||
{
|
||||
return (true, "This will delete a branch or tag");
|
||||
}
|
||||
|
||||
// Check for reflog expire
|
||||
if cmd_lower.contains("reflog") && cmd_lower.contains("expire") {
|
||||
return (true, "Expiring reflog removes recovery points for lost commits");
|
||||
}
|
||||
|
||||
// Check for gc with aggressive or prune
|
||||
if cmd_lower.contains("gc") && (cmd_lower.contains("--aggressive") || cmd_lower.contains("--prune")) {
|
||||
return (true, "Aggressive garbage collection can make recovery difficult");
|
||||
}
|
||||
|
||||
(false, "")
|
||||
}
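// Illustrative usage sketch combining the two checks before executing a shell command;
// `candidate` is an assumed caller-side string.
//
//     let candidate = "git push --force origin main";
//     if !is_safe_git_command(candidate) {
//         let (destructive, warning) = is_destructive_git_command(candidate);
//         if destructive {
//             // surface `warning` and ask the user to confirm before running
//         }
//     }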
|
||||
|
||||
/// Format git state for human-readable display
|
||||
///
|
||||
/// Example output:
|
||||
/// ```text
|
||||
/// Git Repository: yes
|
||||
/// Current branch: feature-branch
|
||||
/// Main branch: main
|
||||
/// Status: 3 modified, 1 untracked
|
||||
/// Remote: https://github.com/user/repo.git
|
||||
/// ```
|
||||
pub fn format_git_status(state: &GitState) -> String {
|
||||
if !state.is_git_repo {
|
||||
return "Not a git repository".to_string();
|
||||
}
|
||||
|
||||
let mut lines = Vec::new();
|
||||
|
||||
lines.push("Git Repository: yes".to_string());
|
||||
|
||||
if let Some(branch) = &state.current_branch {
|
||||
lines.push(format!("Current branch: {}", branch));
|
||||
} else {
|
||||
lines.push("Current branch: (detached HEAD)".to_string());
|
||||
}
|
||||
|
||||
if let Some(main) = &state.main_branch {
|
||||
lines.push(format!("Main branch: {}", main));
|
||||
}
|
||||
|
||||
// Summarize status
|
||||
if state.status.is_empty() {
|
||||
lines.push("Status: clean working tree".to_string());
|
||||
} else {
|
||||
let mut modified = 0;
|
||||
let mut added = 0;
|
||||
let mut deleted = 0;
|
||||
let mut renamed = 0;
|
||||
let mut untracked = 0;
|
||||
|
||||
for status in &state.status {
|
||||
match status {
|
||||
GitFileStatus::Modified { .. } => modified += 1,
|
||||
GitFileStatus::Added { .. } => added += 1,
|
||||
GitFileStatus::Deleted { .. } => deleted += 1,
|
||||
GitFileStatus::Renamed { .. } => renamed += 1,
|
||||
GitFileStatus::Untracked { .. } => untracked += 1,
|
||||
}
|
||||
}
|
||||
|
||||
let mut status_parts = Vec::new();
|
||||
if modified > 0 {
|
||||
status_parts.push(format!("{} modified", modified));
|
||||
}
|
||||
if added > 0 {
|
||||
status_parts.push(format!("{} added", added));
|
||||
}
|
||||
if deleted > 0 {
|
||||
status_parts.push(format!("{} deleted", deleted));
|
||||
}
|
||||
if renamed > 0 {
|
||||
status_parts.push(format!("{} renamed", renamed));
|
||||
}
|
||||
if untracked > 0 {
|
||||
status_parts.push(format!("{} untracked", untracked));
|
||||
}
|
||||
|
||||
lines.push(format!("Status: {}", status_parts.join(", ")));
|
||||
}
|
||||
|
||||
if let Some(url) = &state.remote_url {
|
||||
lines.push(format!("Remote: {}", url));
|
||||
} else {
|
||||
lines.push("Remote: (none)".to_string());
|
||||
}
|
||||
|
||||
lines.join("\n")
|
||||
}
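// Illustrative usage sketch: detect and print the repository state for the current
// directory (assumes the caller returns color_eyre's Result so `?` applies).
//
//     let state = detect_git_state(Path::new("."))?;
//     if state.is_git_repo {
//         println!("{}", format_git_status(&state));
//     }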
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_is_safe_git_command() {
|
||||
// Safe commands
|
||||
assert!(is_safe_git_command("git status"));
|
||||
assert!(is_safe_git_command("git log --oneline"));
|
||||
assert!(is_safe_git_command("git diff HEAD"));
|
||||
assert!(is_safe_git_command("git branch -v"));
|
||||
assert!(is_safe_git_command("git remote -v"));
|
||||
assert!(is_safe_git_command("git config --get user.name"));
|
||||
|
||||
// Unsafe commands
|
||||
assert!(!is_safe_git_command("git commit -m test"));
|
||||
assert!(!is_safe_git_command("git push origin main"));
|
||||
assert!(!is_safe_git_command("git branch -D feature"));
|
||||
assert!(!is_safe_git_command("git remote add origin url"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_destructive_git_command() {
|
||||
// Destructive commands
|
||||
let (is_dest, msg) = is_destructive_git_command("git push --force origin main");
|
||||
assert!(is_dest);
|
||||
assert!(msg.contains("Force push"));
|
||||
|
||||
let (is_dest, msg) = is_destructive_git_command("git reset --hard HEAD~1");
|
||||
assert!(is_dest);
|
||||
assert!(msg.contains("Hard reset"));
|
||||
|
||||
let (is_dest, msg) = is_destructive_git_command("git clean -fd");
|
||||
assert!(is_dest);
|
||||
assert!(msg.contains("clean"));
|
||||
|
||||
let (is_dest, msg) = is_destructive_git_command("git rebase main");
|
||||
assert!(is_dest);
|
||||
assert!(msg.contains("Rebase"));
|
||||
|
||||
let (is_dest, msg) = is_destructive_git_command("git commit --amend");
|
||||
assert!(is_dest);
|
||||
assert!(msg.contains("Amending"));
|
||||
|
||||
// Non-destructive commands
|
||||
let (is_dest, _) = is_destructive_git_command("git status");
|
||||
assert!(!is_dest);
|
||||
|
||||
let (is_dest, _) = is_destructive_git_command("git log");
|
||||
assert!(!is_dest);
|
||||
|
||||
let (is_dest, _) = is_destructive_git_command("git diff");
|
||||
assert!(!is_dest);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_git_state_not_a_repo() {
|
||||
let state = GitState::not_a_repo();
|
||||
assert!(!state.is_git_repo);
|
||||
assert!(state.current_branch.is_none());
|
||||
assert!(state.main_branch.is_none());
|
||||
assert!(state.status.is_empty());
|
||||
assert!(!state.has_uncommitted_changes);
|
||||
assert!(state.remote_url.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_git_file_status_path() {
|
||||
let status = GitFileStatus::Modified {
|
||||
path: "test.rs".to_string(),
|
||||
};
|
||||
assert_eq!(status.path(), "test.rs");
|
||||
|
||||
let status = GitFileStatus::Renamed {
|
||||
from: "old.rs".to_string(),
|
||||
to: "new.rs".to_string(),
|
||||
};
|
||||
assert_eq!(status.path(), "new.rs");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_git_status_not_repo() {
|
||||
let state = GitState::not_a_repo();
|
||||
let formatted = format_git_status(&state);
|
||||
assert_eq!(formatted, "Not a git repository");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_git_status_clean() {
|
||||
let state = GitState {
|
||||
is_git_repo: true,
|
||||
current_branch: Some("main".to_string()),
|
||||
main_branch: Some("main".to_string()),
|
||||
status: Vec::new(),
|
||||
has_uncommitted_changes: false,
|
||||
remote_url: Some("https://github.com/user/repo.git".to_string()),
|
||||
};
|
||||
|
||||
let formatted = format_git_status(&state);
|
||||
assert!(formatted.contains("Git Repository: yes"));
|
||||
assert!(formatted.contains("Current branch: main"));
|
||||
assert!(formatted.contains("clean working tree"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_git_status_with_changes() {
|
||||
let state = GitState {
|
||||
is_git_repo: true,
|
||||
current_branch: Some("feature".to_string()),
|
||||
main_branch: Some("main".to_string()),
|
||||
status: vec![
|
||||
GitFileStatus::Modified {
|
||||
path: "file1.rs".to_string(),
|
||||
},
|
||||
GitFileStatus::Modified {
|
||||
path: "file2.rs".to_string(),
|
||||
},
|
||||
GitFileStatus::Untracked {
|
||||
path: "new.rs".to_string(),
|
||||
},
|
||||
],
|
||||
has_uncommitted_changes: true,
|
||||
remote_url: None,
|
||||
};
|
||||
|
||||
let formatted = format_git_status(&state);
|
||||
assert!(formatted.contains("2 modified"));
|
||||
assert!(formatted.contains("1 untracked"));
|
||||
}
|
||||
}
|
||||
1130
crates/core/agent/src/lib.rs
Normal file
File diff suppressed because it is too large
295
crates/core/agent/src/session.rs
Normal file
@@ -0,0 +1,295 @@
|
||||
use color_eyre::eyre::{Result, eyre};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::{Duration, SystemTime};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SessionStats {
|
||||
pub start_time: SystemTime,
|
||||
pub total_messages: usize,
|
||||
pub total_tool_calls: usize,
|
||||
pub total_duration: Duration,
|
||||
pub estimated_tokens: usize,
|
||||
}
|
||||
|
||||
impl SessionStats {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
start_time: SystemTime::now(),
|
||||
total_messages: 0,
|
||||
total_tool_calls: 0,
|
||||
total_duration: Duration::ZERO,
|
||||
estimated_tokens: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn record_message(&mut self, tokens: usize, duration: Duration) {
|
||||
self.total_messages += 1;
|
||||
self.estimated_tokens += tokens;
|
||||
self.total_duration += duration;
|
||||
}
|
||||
|
||||
pub fn record_tool_call(&mut self) {
|
||||
self.total_tool_calls += 1;
|
||||
}
|
||||
|
||||
pub fn format_duration(d: Duration) -> String {
|
||||
let secs = d.as_secs();
|
||||
if secs < 60 {
|
||||
format!("{}s", secs)
|
||||
} else if secs < 3600 {
|
||||
format!("{}m {}s", secs / 60, secs % 60)
|
||||
} else {
|
||||
format!("{}h {}m", secs / 3600, (secs % 3600) / 60)
|
||||
}
|
||||
}
|
||||
}
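// Illustrative usage sketch: recording one turn and rendering the elapsed time
// (the token count and duration are assumed example values).
//
//     let mut stats = SessionStats::new();
//     stats.record_message(1_200, Duration::from_secs(4));
//     stats.record_tool_call();
//     let shown = SessionStats::format_duration(stats.total_duration); // "4s"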
|
||||
|
||||
impl Default for SessionStats {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SessionHistory {
|
||||
pub user_prompts: Vec<String>,
|
||||
pub assistant_responses: Vec<String>,
|
||||
pub tool_calls: Vec<ToolCallRecord>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCallRecord {
|
||||
pub tool_name: String,
|
||||
pub arguments: String,
|
||||
pub result: String,
|
||||
pub success: bool,
|
||||
}
|
||||
|
||||
impl SessionHistory {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
user_prompts: Vec::new(),
|
||||
assistant_responses: Vec::new(),
|
||||
tool_calls: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_user_message(&mut self, message: String) {
|
||||
self.user_prompts.push(message);
|
||||
}
|
||||
|
||||
pub fn add_assistant_message(&mut self, message: String) {
|
||||
self.assistant_responses.push(message);
|
||||
}
|
||||
|
||||
pub fn add_tool_call(&mut self, record: ToolCallRecord) {
|
||||
self.tool_calls.push(record);
|
||||
}
|
||||
|
||||
pub fn clear(&mut self) {
|
||||
self.user_prompts.clear();
|
||||
self.assistant_responses.clear();
|
||||
self.tool_calls.clear();
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SessionHistory {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a file modification with before/after content
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FileDiff {
|
||||
pub path: PathBuf,
|
||||
pub before: String,
|
||||
pub after: String,
|
||||
pub timestamp: SystemTime,
|
||||
}
|
||||
|
||||
impl FileDiff {
|
||||
/// Create a new file diff
|
||||
pub fn new(path: PathBuf, before: String, after: String) -> Self {
|
||||
Self {
|
||||
path,
|
||||
before,
|
||||
after,
|
||||
timestamp: SystemTime::now(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A checkpoint captures the state of a session at a point in time
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Checkpoint {
|
||||
pub id: String,
|
||||
pub timestamp: SystemTime,
|
||||
pub stats: SessionStats,
|
||||
pub user_prompts: Vec<String>,
|
||||
pub assistant_responses: Vec<String>,
|
||||
pub tool_calls: Vec<ToolCallRecord>,
|
||||
pub file_diffs: Vec<FileDiff>,
|
||||
}
|
||||
|
||||
impl Checkpoint {
|
||||
/// Create a new checkpoint from current session state
|
||||
pub fn new(
|
||||
id: String,
|
||||
stats: SessionStats,
|
||||
history: &SessionHistory,
|
||||
file_diffs: Vec<FileDiff>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id,
|
||||
timestamp: SystemTime::now(),
|
||||
stats,
|
||||
user_prompts: history.user_prompts.clone(),
|
||||
assistant_responses: history.assistant_responses.clone(),
|
||||
tool_calls: history.tool_calls.clone(),
|
||||
file_diffs,
|
||||
}
|
||||
}
|
||||
|
||||
/// Save checkpoint to disk
|
||||
pub fn save(&self, checkpoint_dir: &Path) -> Result<()> {
|
||||
fs::create_dir_all(checkpoint_dir)?;
|
||||
let path = checkpoint_dir.join(format!("{}.json", self.id));
|
||||
let content = serde_json::to_string_pretty(self)?;
|
||||
fs::write(path, content)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load checkpoint from disk
|
||||
pub fn load(checkpoint_dir: &Path, id: &str) -> Result<Self> {
|
||||
let path = checkpoint_dir.join(format!("{}.json", id));
|
||||
let content = fs::read_to_string(&path)
|
||||
.map_err(|e| eyre!("Failed to read checkpoint: {}", e))?;
|
||||
let checkpoint: Checkpoint = serde_json::from_str(&content)
|
||||
.map_err(|e| eyre!("Failed to parse checkpoint: {}", e))?;
|
||||
Ok(checkpoint)
|
||||
}
|
||||
|
||||
/// List all available checkpoints in a directory
|
||||
pub fn list(checkpoint_dir: &Path) -> Result<Vec<String>> {
|
||||
if !checkpoint_dir.exists() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
let mut checkpoints = Vec::new();
|
||||
for entry in fs::read_dir(checkpoint_dir)? {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
if path.extension().and_then(|s| s.to_str()) == Some("json") {
|
||||
if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
|
||||
checkpoints.push(stem.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by checkpoint ID (callers are expected to embed a timestamp in the ID, so this yields chronological order)
|
||||
checkpoints.sort();
|
||||
Ok(checkpoints)
|
||||
}
|
||||
}
|
||||
|
||||
/// Session checkpoint manager
|
||||
pub struct CheckpointManager {
|
||||
checkpoint_dir: PathBuf,
|
||||
file_snapshots: HashMap<PathBuf, String>,
|
||||
}
|
||||
|
||||
impl CheckpointManager {
|
||||
/// Create a new checkpoint manager
|
||||
pub fn new(checkpoint_dir: PathBuf) -> Self {
|
||||
Self {
|
||||
checkpoint_dir,
|
||||
file_snapshots: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Snapshot a file's current content before modification
|
||||
pub fn snapshot_file(&mut self, path: &Path) -> Result<()> {
|
||||
if !self.file_snapshots.contains_key(path) {
|
||||
let content = fs::read_to_string(path).unwrap_or_default();
|
||||
self.file_snapshots.insert(path.to_path_buf(), content);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Create a file diff after modification
|
||||
pub fn create_diff(&self, path: &Path) -> Result<Option<FileDiff>> {
|
||||
if let Some(before) = self.file_snapshots.get(path) {
|
||||
let after = fs::read_to_string(path).unwrap_or_default();
|
||||
if before != &after {
|
||||
Ok(Some(FileDiff::new(
|
||||
path.to_path_buf(),
|
||||
before.clone(),
|
||||
after,
|
||||
)))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get all file diffs since last checkpoint
|
||||
pub fn get_all_diffs(&self) -> Result<Vec<FileDiff>> {
|
||||
let mut diffs = Vec::new();
|
||||
for (path, before) in &self.file_snapshots {
|
||||
let after = fs::read_to_string(path).unwrap_or_default();
|
||||
if before != &after {
|
||||
diffs.push(FileDiff::new(path.clone(), before.clone(), after));
|
||||
}
|
||||
}
|
||||
Ok(diffs)
|
||||
}
|
||||
|
||||
/// Clear file snapshots
|
||||
pub fn clear_snapshots(&mut self) {
|
||||
self.file_snapshots.clear();
|
||||
}
|
||||
|
||||
/// Save a checkpoint
|
||||
pub fn save_checkpoint(
|
||||
&mut self,
|
||||
id: String,
|
||||
stats: SessionStats,
|
||||
history: &SessionHistory,
|
||||
) -> Result<Checkpoint> {
|
||||
let file_diffs = self.get_all_diffs()?;
|
||||
let checkpoint = Checkpoint::new(id, stats, history, file_diffs);
|
||||
checkpoint.save(&self.checkpoint_dir)?;
|
||||
self.clear_snapshots();
|
||||
Ok(checkpoint)
|
||||
}
|
||||
|
||||
/// Load a checkpoint
|
||||
pub fn load_checkpoint(&self, id: &str) -> Result<Checkpoint> {
|
||||
Checkpoint::load(&self.checkpoint_dir, id)
|
||||
}
|
||||
|
||||
/// List all checkpoints
|
||||
pub fn list_checkpoints(&self) -> Result<Vec<String>> {
|
||||
Checkpoint::list(&self.checkpoint_dir)
|
||||
}
|
||||
|
||||
/// Rewind to a checkpoint by restoring file contents
|
||||
pub fn rewind_to(&self, checkpoint_id: &str) -> Result<Vec<PathBuf>> {
|
||||
let checkpoint = self.load_checkpoint(checkpoint_id)?;
|
||||
let mut restored_files = Vec::new();
|
||||
|
||||
// Restore files from diffs (revert to 'before' state)
|
||||
for diff in &checkpoint.file_diffs {
|
||||
fs::write(&diff.path, &diff.before)?;
|
||||
restored_files.push(diff.path.clone());
|
||||
}
|
||||
|
||||
Ok(restored_files)
|
||||
}
|
||||
}
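// Illustrative usage sketch of the snapshot -> edit -> checkpoint -> rewind cycle;
// the checkpoint directory, file path, checkpoint id, and the live `stats`/`history`
// values are assumed caller-side inputs.
//
//     let mut manager = CheckpointManager::new(PathBuf::from(".owlen/checkpoints"));
//     manager.snapshot_file(Path::new("src/main.rs"))?;
//     // ...the agent edits src/main.rs here...
//     let checkpoint = manager.save_checkpoint("cp-001".to_string(), stats, &history)?;
//     // later, undo the edits captured in that checkpoint:
//     let restored = manager.rewind_to(&checkpoint.id)?;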
|
||||
266
crates/core/agent/src/system_prompt.rs
Normal file
@@ -0,0 +1,266 @@
|
||||
//! System Prompt Management
|
||||
//!
|
||||
//! Composes system prompts from multiple sources for agent sessions.
|
||||
|
||||
use std::path::Path;
|
||||
|
||||
/// Builder for composing system prompts
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct SystemPromptBuilder {
|
||||
sections: Vec<PromptSection>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct PromptSection {
|
||||
name: String,
|
||||
content: String,
|
||||
priority: i32, // Lower = earlier in prompt
|
||||
}
|
||||
|
||||
impl SystemPromptBuilder {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Add the base agent prompt
|
||||
pub fn with_base_prompt(mut self, content: impl Into<String>) -> Self {
|
||||
self.sections.push(PromptSection {
|
||||
name: "base".to_string(),
|
||||
content: content.into(),
|
||||
priority: 0,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Add tool usage instructions
|
||||
pub fn with_tool_instructions(mut self, content: impl Into<String>) -> Self {
|
||||
self.sections.push(PromptSection {
|
||||
name: "tools".to_string(),
|
||||
content: content.into(),
|
||||
priority: 10,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Load and add project instructions from CLAUDE.md or .owlen.md
|
||||
pub fn with_project_instructions(mut self, project_root: &Path) -> Self {
|
||||
// Try CLAUDE.md first (Claude Code compatibility)
|
||||
let claude_md = project_root.join("CLAUDE.md");
|
||||
if claude_md.exists() {
|
||||
if let Ok(content) = std::fs::read_to_string(&claude_md) {
|
||||
self.sections.push(PromptSection {
|
||||
name: "project".to_string(),
|
||||
content: format!("# Project Instructions\n\n{}", content),
|
||||
priority: 20,
|
||||
});
|
||||
return self;
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to .owlen.md
|
||||
let owlen_md = project_root.join(".owlen.md");
|
||||
if owlen_md.exists() {
|
||||
if let Ok(content) = std::fs::read_to_string(&owlen_md) {
|
||||
self.sections.push(PromptSection {
|
||||
name: "project".to_string(),
|
||||
content: format!("# Project Instructions\n\n{}", content),
|
||||
priority: 20,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Add skill content
|
||||
pub fn with_skill(mut self, skill_name: &str, content: impl Into<String>) -> Self {
|
||||
self.sections.push(PromptSection {
|
||||
name: format!("skill:{}", skill_name),
|
||||
content: content.into(),
|
||||
priority: 30,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Add hook-injected content (from SessionStart hooks)
|
||||
pub fn with_hook_injection(mut self, content: impl Into<String>) -> Self {
|
||||
self.sections.push(PromptSection {
|
||||
name: "hook".to_string(),
|
||||
content: content.into(),
|
||||
priority: 40,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Add custom section
|
||||
pub fn with_section(mut self, name: impl Into<String>, content: impl Into<String>, priority: i32) -> Self {
|
||||
self.sections.push(PromptSection {
|
||||
name: name.into(),
|
||||
content: content.into(),
|
||||
priority,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the final system prompt
|
||||
pub fn build(mut self) -> String {
|
||||
// Sort by priority
|
||||
self.sections.sort_by_key(|s| s.priority);
|
||||
|
||||
// Join sections with separators
|
||||
self.sections
|
||||
.iter()
|
||||
.map(|s| s.content.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n\n---\n\n")
|
||||
}
|
||||
|
||||
/// Check if any content has been added
|
||||
pub fn is_empty(&self) -> bool {
|
||||
self.sections.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// Default base prompt for Owlen agent
|
||||
pub fn default_base_prompt() -> &'static str {
|
||||
r#"You are Owlen, an AI assistant that helps with software engineering tasks.
|
||||
|
||||
You have access to tools for reading files, writing code, running commands, and searching the web.
|
||||
|
||||
## Guidelines
|
||||
|
||||
1. Be direct and concise in your responses
|
||||
2. Use tools to gather information before making changes
|
||||
3. Explain your reasoning when making decisions
|
||||
4. Ask for clarification when requirements are unclear
|
||||
5. Prefer editing existing files over creating new ones
|
||||
|
||||
## Tool Usage
|
||||
|
||||
- Use `read` to examine file contents before editing
|
||||
- Use `glob` and `grep` to find relevant files
|
||||
- Use `edit` for precise changes, `write` for new files
|
||||
- Use `bash` for running tests and commands
|
||||
- Use `web_search` for current information"#
|
||||
}
|
||||
|
||||
/// Generate tool instructions based on available tools
|
||||
pub fn generate_tool_instructions(tool_names: &[&str]) -> String {
|
||||
let mut instructions = String::from("## Available Tools\n\n");
|
||||
|
||||
for name in tool_names {
|
||||
let desc = match *name {
|
||||
"read" => "Read file contents",
|
||||
"write" => "Create or overwrite a file",
|
||||
"edit" => "Edit a file by replacing text",
|
||||
"multi_edit" => "Apply multiple edits atomically",
|
||||
"glob" => "Find files by pattern",
|
||||
"grep" => "Search file contents",
|
||||
"ls" => "List directory contents",
|
||||
"bash" => "Execute shell commands",
|
||||
"web_search" => "Search the web",
|
||||
"web_fetch" => "Fetch a URL",
|
||||
"todo_write" => "Update task list",
|
||||
"ask_user" => "Ask user a question",
|
||||
_ => continue,
|
||||
};
|
||||
instructions.push_str(&format!("- `{}`: {}\n", name, desc));
|
||||
}
|
||||
|
||||
instructions
|
||||
}
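// Illustrative usage sketch: composing a full system prompt from the pieces in this
// module (the project root path is an assumed value).
//
//     let tools = ["read", "edit", "bash"];
//     let prompt = SystemPromptBuilder::new()
//         .with_base_prompt(default_base_prompt())
//         .with_tool_instructions(generate_tool_instructions(&tools))
//         .with_project_instructions(Path::new("."))
//         .build();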
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_builder() {
|
||||
let prompt = SystemPromptBuilder::new()
|
||||
.with_base_prompt("You are helpful")
|
||||
.with_tool_instructions("Use tools wisely")
|
||||
.build();
|
||||
|
||||
assert!(prompt.contains("You are helpful"));
|
||||
assert!(prompt.contains("Use tools wisely"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_priority_ordering() {
|
||||
let prompt = SystemPromptBuilder::new()
|
||||
.with_section("last", "Third", 100)
|
||||
.with_section("first", "First", 0)
|
||||
.with_section("middle", "Second", 50)
|
||||
.build();
|
||||
|
||||
let first_pos = prompt.find("First").unwrap();
|
||||
let second_pos = prompt.find("Second").unwrap();
|
||||
let third_pos = prompt.find("Third").unwrap();
|
||||
|
||||
assert!(first_pos < second_pos);
|
||||
assert!(second_pos < third_pos);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_base_prompt() {
|
||||
let prompt = default_base_prompt();
|
||||
assert!(prompt.contains("Owlen"));
|
||||
assert!(prompt.contains("Guidelines"));
|
||||
assert!(prompt.contains("Tool Usage"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_generate_tool_instructions() {
|
||||
let tools = vec!["read", "write", "edit", "bash"];
|
||||
let instructions = generate_tool_instructions(&tools);
|
||||
|
||||
assert!(instructions.contains("Available Tools"));
|
||||
assert!(instructions.contains("read"));
|
||||
assert!(instructions.contains("write"));
|
||||
assert!(instructions.contains("edit"));
|
||||
assert!(instructions.contains("bash"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_builder_empty() {
|
||||
let builder = SystemPromptBuilder::new();
|
||||
assert!(builder.is_empty());
|
||||
|
||||
let builder = builder.with_base_prompt("test");
|
||||
assert!(!builder.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_skill_section() {
|
||||
let prompt = SystemPromptBuilder::new()
|
||||
.with_base_prompt("Base")
|
||||
.with_skill("rust", "Rust expertise")
|
||||
.build();
|
||||
|
||||
assert!(prompt.contains("Base"));
|
||||
assert!(prompt.contains("Rust expertise"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hook_injection() {
|
||||
let prompt = SystemPromptBuilder::new()
|
||||
.with_base_prompt("Base")
|
||||
.with_hook_injection("Additional context from hook")
|
||||
.build();
|
||||
|
||||
assert!(prompt.contains("Base"));
|
||||
assert!(prompt.contains("Additional context from hook"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_separator_between_sections() {
|
||||
let prompt = SystemPromptBuilder::new()
|
||||
.with_section("first", "First section", 0)
|
||||
.with_section("second", "Second section", 10)
|
||||
.build();
|
||||
|
||||
assert!(prompt.contains("---"));
|
||||
assert!(prompt.contains("First section"));
|
||||
assert!(prompt.contains("Second section"));
|
||||
}
|
||||
}
|
||||
210
crates/core/agent/tests/checkpoint.rs
Normal file
@@ -0,0 +1,210 @@
|
||||
use agent_core::{Checkpoint, CheckpointManager, FileDiff, SessionHistory, SessionStats};
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
fn test_checkpoint_save_and_load() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let checkpoint_dir = temp_dir.path().to_path_buf();
|
||||
|
||||
let stats = SessionStats::new();
|
||||
let mut history = SessionHistory::new();
|
||||
history.add_user_message("Hello".to_string());
|
||||
history.add_assistant_message("Hi there!".to_string());
|
||||
|
||||
let file_diffs = vec![FileDiff::new(
|
||||
PathBuf::from("test.txt"),
|
||||
"before".to_string(),
|
||||
"after".to_string(),
|
||||
)];
|
||||
|
||||
let checkpoint = Checkpoint::new(
|
||||
"test-checkpoint".to_string(),
|
||||
stats.clone(),
|
||||
&history,
|
||||
file_diffs,
|
||||
);
|
||||
|
||||
// Save checkpoint
|
||||
checkpoint.save(&checkpoint_dir).unwrap();
|
||||
|
||||
// Load checkpoint
|
||||
let loaded = Checkpoint::load(&checkpoint_dir, "test-checkpoint").unwrap();
|
||||
|
||||
assert_eq!(loaded.id, "test-checkpoint");
|
||||
assert_eq!(loaded.user_prompts, vec!["Hello"]);
|
||||
assert_eq!(loaded.assistant_responses, vec!["Hi there!"]);
|
||||
assert_eq!(loaded.file_diffs.len(), 1);
|
||||
assert_eq!(loaded.file_diffs[0].path, PathBuf::from("test.txt"));
|
||||
assert_eq!(loaded.file_diffs[0].before, "before");
|
||||
assert_eq!(loaded.file_diffs[0].after, "after");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_checkpoint_list() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let checkpoint_dir = temp_dir.path().to_path_buf();
|
||||
|
||||
// Create a few checkpoints
|
||||
for i in 1..=3 {
|
||||
let checkpoint = Checkpoint::new(
|
||||
format!("checkpoint-{}", i),
|
||||
SessionStats::new(),
|
||||
&SessionHistory::new(),
|
||||
vec![],
|
||||
);
|
||||
checkpoint.save(&checkpoint_dir).unwrap();
|
||||
}
|
||||
|
||||
let checkpoints = Checkpoint::list(&checkpoint_dir).unwrap();
|
||||
assert_eq!(checkpoints.len(), 3);
|
||||
assert!(checkpoints.contains(&"checkpoint-1".to_string()));
|
||||
assert!(checkpoints.contains(&"checkpoint-2".to_string()));
|
||||
assert!(checkpoints.contains(&"checkpoint-3".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_checkpoint_manager_snapshot_and_diff() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let checkpoint_dir = temp_dir.path().join("checkpoints");
|
||||
let test_file = temp_dir.path().join("test.txt");
|
||||
|
||||
// Create initial file content
|
||||
fs::write(&test_file, "initial content").unwrap();
|
||||
|
||||
let mut manager = CheckpointManager::new(checkpoint_dir.clone());
|
||||
|
||||
// Snapshot the file
|
||||
manager.snapshot_file(&test_file).unwrap();
|
||||
|
||||
// Modify the file
|
||||
fs::write(&test_file, "modified content").unwrap();
|
||||
|
||||
// Create a diff
|
||||
let diff = manager.create_diff(&test_file).unwrap();
|
||||
assert!(diff.is_some());
|
||||
|
||||
let diff = diff.unwrap();
|
||||
assert_eq!(diff.path, test_file);
|
||||
assert_eq!(diff.before, "initial content");
|
||||
assert_eq!(diff.after, "modified content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_checkpoint_manager_save_and_restore() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let checkpoint_dir = temp_dir.path().join("checkpoints");
|
||||
let test_file = temp_dir.path().join("test.txt");
|
||||
|
||||
// Create initial file content
|
||||
fs::write(&test_file, "initial content").unwrap();
|
||||
|
||||
let mut manager = CheckpointManager::new(checkpoint_dir.clone());
|
||||
|
||||
// Snapshot the file
|
||||
manager.snapshot_file(&test_file).unwrap();
|
||||
|
||||
// Modify the file
|
||||
fs::write(&test_file, "modified content").unwrap();
|
||||
|
||||
// Save checkpoint
|
||||
let mut history = SessionHistory::new();
|
||||
history.add_user_message("test".to_string());
|
||||
let checkpoint = manager
|
||||
.save_checkpoint("test-checkpoint".to_string(), SessionStats::new(), &history)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(checkpoint.file_diffs.len(), 1);
|
||||
assert_eq!(checkpoint.file_diffs[0].before, "initial content");
|
||||
assert_eq!(checkpoint.file_diffs[0].after, "modified content");
|
||||
|
||||
// Modify file again
|
||||
fs::write(&test_file, "final content").unwrap();
|
||||
assert_eq!(fs::read_to_string(&test_file).unwrap(), "final content");
|
||||
|
||||
// Rewind to checkpoint
|
||||
let restored_files = manager.rewind_to("test-checkpoint").unwrap();
|
||||
assert_eq!(restored_files.len(), 1);
|
||||
assert_eq!(restored_files[0], test_file);
|
||||
|
||||
// File should be reverted to initial content (before the checkpoint)
|
||||
assert_eq!(fs::read_to_string(&test_file).unwrap(), "initial content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_checkpoint_manager_multiple_files() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let checkpoint_dir = temp_dir.path().join("checkpoints");
|
||||
let test_file1 = temp_dir.path().join("file1.txt");
|
||||
let test_file2 = temp_dir.path().join("file2.txt");
|
||||
|
||||
// Create initial files
|
||||
fs::write(&test_file1, "file1 initial").unwrap();
|
||||
fs::write(&test_file2, "file2 initial").unwrap();
|
||||
|
||||
let mut manager = CheckpointManager::new(checkpoint_dir.clone());
|
||||
|
||||
// Snapshot both files
|
||||
manager.snapshot_file(&test_file1).unwrap();
|
||||
manager.snapshot_file(&test_file2).unwrap();
|
||||
|
||||
// Modify both files
|
||||
fs::write(&test_file1, "file1 modified").unwrap();
|
||||
fs::write(&test_file2, "file2 modified").unwrap();
|
||||
|
||||
// Save checkpoint
|
||||
let checkpoint = manager
|
||||
.save_checkpoint(
|
||||
"multi-file-checkpoint".to_string(),
|
||||
SessionStats::new(),
|
||||
&SessionHistory::new(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(checkpoint.file_diffs.len(), 2);
|
||||
|
||||
// Modify files again
|
||||
fs::write(&test_file1, "file1 final").unwrap();
|
||||
fs::write(&test_file2, "file2 final").unwrap();
|
||||
|
||||
// Rewind
|
||||
let restored_files = manager.rewind_to("multi-file-checkpoint").unwrap();
|
||||
assert_eq!(restored_files.len(), 2);
|
||||
|
||||
// Both files should be reverted
|
||||
assert_eq!(fs::read_to_string(&test_file1).unwrap(), "file1 initial");
|
||||
assert_eq!(fs::read_to_string(&test_file2).unwrap(), "file2 initial");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_checkpoint_no_changes() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let checkpoint_dir = temp_dir.path().join("checkpoints");
|
||||
let test_file = temp_dir.path().join("test.txt");
|
||||
|
||||
// Create file
|
||||
fs::write(&test_file, "content").unwrap();
|
||||
|
||||
let mut manager = CheckpointManager::new(checkpoint_dir.clone());
|
||||
|
||||
// Snapshot the file
|
||||
manager.snapshot_file(&test_file).unwrap();
|
||||
|
||||
// Don't modify the file
|
||||
|
||||
// Create diff - should be None because nothing changed
|
||||
let diff = manager.create_diff(&test_file).unwrap();
|
||||
assert!(diff.is_none());
|
||||
|
||||
// Save checkpoint - should have no diffs
|
||||
let checkpoint = manager
|
||||
.save_checkpoint(
|
||||
"no-change-checkpoint".to_string(),
|
||||
SessionStats::new(),
|
||||
&SessionHistory::new(),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(checkpoint.file_diffs.len(), 0);
|
||||
}
|
||||
276
crates/core/agent/tests/streaming.rs
Normal file
@@ -0,0 +1,276 @@
|
||||
use agent_core::{create_event_channel, run_agent_loop_streaming, AgentEvent, ToolContext};
|
||||
use async_trait::async_trait;
|
||||
use futures_util::stream;
|
||||
use llm_core::{
|
||||
ChatMessage, ChatOptions, LlmError, StreamChunk, LlmProvider, Tool, ToolCallDelta,
|
||||
};
|
||||
use permissions::{Mode, PermissionManager};
|
||||
use std::pin::Pin;
|
||||
|
||||
/// Mock LLM provider for testing streaming
|
||||
struct MockStreamingProvider {
|
||||
responses: Vec<MockResponse>,
|
||||
}
|
||||
|
||||
enum MockResponse {
|
||||
/// Text-only response (no tool calls)
|
||||
Text(Vec<String>), // Chunks of text
|
||||
/// Tool call response
|
||||
ToolCall {
|
||||
text_chunks: Vec<String>,
|
||||
tool_id: String,
|
||||
tool_name: String,
|
||||
tool_args: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for MockStreamingProvider {
|
||||
fn name(&self) -> &str {
|
||||
"mock"
|
||||
}
|
||||
|
||||
fn model(&self) -> &str {
|
||||
"mock-model"
|
||||
}
|
||||
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
_options: &ChatOptions,
|
||||
_tools: Option<&[Tool]>,
|
||||
) -> Result<Pin<Box<dyn futures_util::Stream<Item = Result<StreamChunk, LlmError>> + Send>>, LlmError> {
|
||||
// Determine which response to use based on message count
|
||||
let response_idx = (messages.len() / 2).min(self.responses.len() - 1);
|
||||
let response = &self.responses[response_idx];
|
||||
|
||||
let chunks: Vec<Result<StreamChunk, LlmError>> = match response {
|
||||
MockResponse::Text(text_chunks) => text_chunks
|
||||
.iter()
|
||||
.map(|text| {
|
||||
Ok(StreamChunk {
|
||||
content: Some(text.clone()),
|
||||
tool_calls: None,
|
||||
done: false,
|
||||
usage: None,
|
||||
})
|
||||
})
|
||||
.collect(),
|
||||
MockResponse::ToolCall {
|
||||
text_chunks,
|
||||
tool_id,
|
||||
tool_name,
|
||||
tool_args,
|
||||
} => {
|
||||
let mut result = vec![];
|
||||
|
||||
// First emit text chunks
|
||||
for text in text_chunks {
|
||||
result.push(Ok(StreamChunk {
|
||||
content: Some(text.clone()),
|
||||
tool_calls: None,
|
||||
done: false,
|
||||
usage: None,
|
||||
}));
|
||||
}
|
||||
|
||||
// Then emit tool call in chunks
|
||||
result.push(Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: 0,
|
||||
id: Some(tool_id.clone()),
|
||||
function_name: Some(tool_name.clone()),
|
||||
arguments_delta: None,
|
||||
}]),
|
||||
done: false,
|
||||
usage: None,
|
||||
}));
|
||||
|
||||
// Emit args in chunks
|
||||
for chunk in tool_args.chars().collect::<Vec<_>>().chunks(5) {
|
||||
result.push(Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index: 0,
|
||||
id: None,
|
||||
function_name: None,
|
||||
arguments_delta: Some(chunk.iter().collect()),
|
||||
}]),
|
||||
done: false,
|
||||
usage: None,
|
||||
}));
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
};
|
||||
|
||||
Ok(Box::pin(stream::iter(chunks)))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_text_only() {
|
||||
let provider = MockStreamingProvider {
|
||||
responses: vec![MockResponse::Text(vec![
|
||||
"Hello".to_string(),
|
||||
" ".to_string(),
|
||||
"world".to_string(),
|
||||
"!".to_string(),
|
||||
])],
|
||||
};
|
||||
|
||||
let perms = PermissionManager::new(Mode::Plan);
|
||||
let ctx = ToolContext::default();
|
||||
let (tx, mut rx) = create_event_channel();
|
||||
|
||||
// Spawn the agent loop
|
||||
let handle = tokio::spawn(async move {
|
||||
run_agent_loop_streaming(
|
||||
&provider,
|
||||
"Say hello",
|
||||
&ChatOptions::default(),
|
||||
&perms,
|
||||
&ctx,
|
||||
tx,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
// Collect events
|
||||
let mut text_deltas = vec![];
|
||||
let mut done_response = None;
|
||||
|
||||
while let Some(event) = rx.recv().await {
|
||||
match event {
|
||||
AgentEvent::TextDelta(text) => {
|
||||
text_deltas.push(text);
|
||||
}
|
||||
AgentEvent::Done { final_response } => {
|
||||
done_response = Some(final_response);
|
||||
break;
|
||||
}
|
||||
AgentEvent::Error(e) => {
|
||||
panic!("Unexpected error: {}", e);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for agent loop to complete
|
||||
let result = handle.await.unwrap();
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify events
|
||||
assert_eq!(text_deltas, vec!["Hello", " ", "world", "!"]);
|
||||
assert_eq!(done_response, Some("Hello world!".to_string()));
|
||||
assert_eq!(result.unwrap(), "Hello world!");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_with_tool_call() {
|
||||
let provider = MockStreamingProvider {
|
||||
responses: vec![
|
||||
MockResponse::ToolCall {
|
||||
text_chunks: vec!["Let me ".to_string(), "check...".to_string()],
|
||||
tool_id: "call_123".to_string(),
|
||||
tool_name: "glob".to_string(),
|
||||
tool_args: r#"{"pattern":"*.rs"}"#.to_string(),
|
||||
},
|
||||
MockResponse::Text(vec!["Found ".to_string(), "the files!".to_string()]),
|
||||
],
|
||||
};
|
||||
|
||||
let perms = PermissionManager::new(Mode::Plan);
|
||||
let ctx = ToolContext::default();
|
||||
let (tx, mut rx) = create_event_channel();
|
||||
|
||||
// Spawn the agent loop
|
||||
let handle = tokio::spawn(async move {
|
||||
run_agent_loop_streaming(
|
||||
&provider,
|
||||
"Find Rust files",
|
||||
&ChatOptions::default(),
|
||||
&perms,
|
||||
&ctx,
|
||||
tx,
|
||||
)
|
||||
.await
|
||||
});
|
||||
|
||||
// Collect events
|
||||
let mut text_deltas = vec![];
|
||||
let mut tool_starts = vec![];
|
||||
let mut tool_outputs = vec![];
|
||||
let mut tool_ends = vec![];
|
||||
|
||||
while let Some(event) = rx.recv().await {
|
||||
match event {
|
||||
AgentEvent::TextDelta(text) => {
|
||||
text_deltas.push(text);
|
||||
}
|
||||
AgentEvent::ToolStart {
|
||||
tool_name,
|
||||
tool_id,
|
||||
} => {
|
||||
tool_starts.push((tool_name, tool_id));
|
||||
}
|
||||
AgentEvent::ToolOutput {
|
||||
tool_id,
|
||||
content,
|
||||
is_error,
|
||||
} => {
|
||||
tool_outputs.push((tool_id, content, is_error));
|
||||
}
|
||||
AgentEvent::ToolEnd { tool_id, success } => {
|
||||
tool_ends.push((tool_id, success));
|
||||
}
|
||||
AgentEvent::Done { .. } => {
|
||||
break;
|
||||
}
|
||||
AgentEvent::Error(e) => {
|
||||
panic!("Unexpected error: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Wait for agent loop to complete
|
||||
let result = handle.await.unwrap();
|
||||
assert!(result.is_ok());
|
||||
|
||||
// Verify we got text deltas from both responses
|
||||
assert!(text_deltas.contains(&"Let me ".to_string()));
|
||||
assert!(text_deltas.contains(&"check...".to_string()));
|
||||
assert!(text_deltas.contains(&"Found ".to_string()));
|
||||
assert!(text_deltas.contains(&"the files!".to_string()));
|
||||
|
||||
// Verify tool events
|
||||
assert_eq!(tool_starts.len(), 1);
|
||||
assert_eq!(tool_starts[0].0, "glob");
|
||||
assert_eq!(tool_starts[0].1, "call_123");
|
||||
|
||||
assert_eq!(tool_outputs.len(), 1);
|
||||
assert_eq!(tool_outputs[0].0, "call_123");
|
||||
assert!(!tool_outputs[0].2); // not an error
|
||||
|
||||
assert_eq!(tool_ends.len(), 1);
|
||||
assert_eq!(tool_ends[0].0, "call_123");
|
||||
assert!(tool_ends[0].1); // success
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_channel_creation() {
|
||||
let (tx, mut rx) = create_event_channel();
|
||||
|
||||
// Test that channel works
|
||||
tx.send(AgentEvent::TextDelta("test".to_string()))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let event = rx.recv().await.unwrap();
|
||||
match event {
|
||||
AgentEvent::TextDelta(text) => assert_eq!(text, "test"),
|
||||
_ => panic!("Wrong event type"),
|
||||
}
|
||||
}
|
||||
114
crates/core/agent/tests/tool_context.rs
Normal file
@@ -0,0 +1,114 @@
|
||||
// Test that ToolContext properly wires up the placeholder tools
|
||||
use agent_core::{ToolContext, execute_tool};
|
||||
use permissions::{Mode, PermissionManager};
|
||||
use tools_todo::{TodoList, TodoStatus};
|
||||
use tools_bash::ShellManager;
|
||||
use serde_json::json;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_todo_write_with_context() {
|
||||
let todo_list = TodoList::new();
|
||||
let ctx = ToolContext::new().with_todo_list(todo_list.clone());
|
||||
let perms = PermissionManager::new(Mode::Code); // Allow all tools
|
||||
|
||||
let arguments = json!({
|
||||
"todos": [
|
||||
{
|
||||
"content": "First task",
|
||||
"status": "pending",
|
||||
"active_form": "Working on first task"
|
||||
},
|
||||
{
|
||||
"content": "Second task",
|
||||
"status": "in_progress",
|
||||
"active_form": "Working on second task"
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let result = execute_tool("todo_write", &arguments, &perms, &ctx).await;
|
||||
assert!(result.is_ok(), "TodoWrite should succeed: {:?}", result);
|
||||
|
||||
// Verify the todos were written
|
||||
let todos = todo_list.read();
|
||||
assert_eq!(todos.len(), 2);
|
||||
assert_eq!(todos[0].content, "First task");
|
||||
assert_eq!(todos[1].status, TodoStatus::InProgress);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_todo_write_without_context() {
|
||||
let ctx = ToolContext::new(); // No todo_list
|
||||
let perms = PermissionManager::new(Mode::Code);
|
||||
|
||||
let arguments = json!({
|
||||
"todos": []
|
||||
});
|
||||
|
||||
let result = execute_tool("todo_write", &arguments, &perms, &ctx).await;
|
||||
assert!(result.is_err(), "TodoWrite should fail without TodoList");
|
||||
assert!(result.unwrap_err().to_string().contains("not available"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bash_output_with_context() {
|
||||
let manager = ShellManager::new();
|
||||
let ctx = ToolContext::new().with_shell_manager(manager.clone());
|
||||
let perms = PermissionManager::new(Mode::Code);
|
||||
|
||||
// Start a shell and run a command
|
||||
let shell_id = manager.start_shell().await.unwrap();
|
||||
let _ = manager.execute(&shell_id, "echo test", None).await.unwrap();
|
||||
|
||||
let arguments = json!({
|
||||
"shell_id": shell_id
|
||||
});
|
||||
|
||||
let result = execute_tool("bash_output", &arguments, &perms, &ctx).await;
|
||||
assert!(result.is_ok(), "BashOutput should succeed: {:?}", result);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bash_output_without_context() {
|
||||
let ctx = ToolContext::new(); // No shell_manager
|
||||
let perms = PermissionManager::new(Mode::Code);
|
||||
|
||||
let arguments = json!({
|
||||
"shell_id": "fake-id"
|
||||
});
|
||||
|
||||
let result = execute_tool("bash_output", &arguments, &perms, &ctx).await;
|
||||
assert!(result.is_err(), "BashOutput should fail without ShellManager");
|
||||
assert!(result.unwrap_err().to_string().contains("not available"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_kill_shell_with_context() {
|
||||
let manager = ShellManager::new();
|
||||
let ctx = ToolContext::new().with_shell_manager(manager.clone());
|
||||
let perms = PermissionManager::new(Mode::Code);
|
||||
|
||||
// Start a shell
|
||||
let shell_id = manager.start_shell().await.unwrap();
|
||||
|
||||
let arguments = json!({
|
||||
"shell_id": shell_id
|
||||
});
|
||||
|
||||
let result = execute_tool("kill_shell", &arguments, &perms, &ctx).await;
|
||||
assert!(result.is_ok(), "KillShell should succeed: {:?}", result);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_ask_user_without_context() {
|
||||
let ctx = ToolContext::new(); // No ask_sender
|
||||
let perms = PermissionManager::new(Mode::Code);
|
||||
|
||||
let arguments = json!({
|
||||
"questions": []
|
||||
});
|
||||
|
||||
let result = execute_tool("ask_user", &arguments, &perms, &ctx).await;
|
||||
assert!(result.is_err(), "AskUser should fail without AskSender");
|
||||
assert!(result.unwrap_err().to_string().contains("not available"));
|
||||
}
|
||||
16
crates/integration/mcp-client/Cargo.toml
Normal file
@@ -0,0 +1,16 @@
|
||||
[package]
|
||||
name = "mcp-client"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
rust-version.workspace = true
|
||||
|
||||
[dependencies]
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
tokio = { version = "1.39", features = ["process", "io-util", "sync", "time"] }
|
||||
color-eyre = "0.6"
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = "3.23.0"
|
||||
tokio = { version = "1.39", features = ["macros", "rt-multi-thread"] }
|
||||
272
crates/integration/mcp-client/src/lib.rs
Normal file
@@ -0,0 +1,272 @@
|
||||
use color_eyre::eyre::{Result, eyre};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::process::Stdio;
|
||||
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
|
||||
use tokio::process::{Child, Command};
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
/// JSON-RPC 2.0 request
|
||||
#[derive(Debug, Serialize)]
|
||||
struct JsonRpcRequest {
|
||||
jsonrpc: String,
|
||||
id: u64,
|
||||
method: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
params: Option<Value>,
|
||||
}
|
||||
|
||||
/// JSON-RPC 2.0 response
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct JsonRpcResponse {
|
||||
jsonrpc: String,
|
||||
id: u64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
result: Option<Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
error: Option<JsonRpcError>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct JsonRpcError {
|
||||
code: i32,
|
||||
message: String,
|
||||
}
|
||||
|
||||
/// MCP server capabilities
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ServerCapabilities {
|
||||
#[serde(default)]
|
||||
pub tools: Option<ToolsCapability>,
|
||||
#[serde(default)]
|
||||
pub resources: Option<ResourcesCapability>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ToolsCapability {
|
||||
#[serde(default)]
|
||||
pub list_changed: Option<bool>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct ResourcesCapability {
|
||||
#[serde(default)]
|
||||
pub subscribe: Option<bool>,
|
||||
#[serde(default)]
|
||||
pub list_changed: Option<bool>,
|
||||
}
|
||||
|
||||
/// MCP Tool definition
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct McpTool {
|
||||
pub name: String,
|
||||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
#[serde(default)]
|
||||
pub input_schema: Option<Value>,
|
||||
}
|
||||
|
||||
/// MCP Resource definition
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct McpResource {
|
||||
pub uri: String,
|
||||
#[serde(default)]
|
||||
pub name: Option<String>,
|
||||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
#[serde(default)]
|
||||
pub mime_type: Option<String>,
|
||||
}
|
||||
|
||||
/// MCP Client over stdio transport
|
||||
pub struct McpClient {
|
||||
process: Mutex<Child>,
|
||||
next_id: Mutex<u64>,
|
||||
server_name: String,
|
||||
}
|
||||
|
||||
impl McpClient {
|
||||
/// Create a new MCP client by spawning a subprocess
|
||||
pub async fn spawn(command: &str, args: &[&str], server_name: &str) -> Result<Self> {
|
||||
let mut child = Command::new(command)
|
||||
.args(args)
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()?;
|
||||
|
||||
// Verify process is running
|
||||
if child.try_wait()?.is_some() {
|
||||
return Err(eyre!("MCP server process exited immediately"));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
process: Mutex::new(child),
|
||||
next_id: Mutex::new(1),
|
||||
server_name: server_name.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Initialize the MCP connection
|
||||
pub async fn initialize(&self) -> Result<ServerCapabilities> {
|
||||
let params = serde_json::json!({
|
||||
"protocolVersion": "2024-11-05",
|
||||
"capabilities": {
|
||||
"roots": {
|
||||
"listChanged": true
|
||||
}
|
||||
},
|
||||
"clientInfo": {
|
||||
"name": "owlen",
|
||||
"version": env!("CARGO_PKG_VERSION")
|
||||
}
|
||||
});
|
||||
|
||||
let response = self.send_request("initialize", Some(params)).await?;
|
||||
|
||||
let capabilities = response
|
||||
.get("capabilities")
|
||||
.ok_or_else(|| eyre!("No capabilities in initialize response"))?;
|
||||
|
||||
Ok(serde_json::from_value(capabilities.clone())?)
|
||||
}
|
||||
|
||||
/// List available tools
|
||||
pub async fn list_tools(&self) -> Result<Vec<McpTool>> {
|
||||
let response = self.send_request("tools/list", None).await?;
|
||||
|
||||
let tools = response
|
||||
.get("tools")
|
||||
.ok_or_else(|| eyre!("No tools in response"))?;
|
||||
|
||||
Ok(serde_json::from_value(tools.clone())?)
|
||||
}
|
||||
|
||||
/// Call a tool
|
||||
pub async fn call_tool(&self, name: &str, arguments: Value) -> Result<Value> {
|
||||
let params = serde_json::json!({
|
||||
"name": name,
|
||||
"arguments": arguments
|
||||
});
|
||||
|
||||
let response = self.send_request("tools/call", Some(params)).await?;
|
||||
|
||||
response
|
||||
.get("content")
|
||||
.cloned()
|
||||
.ok_or_else(|| eyre!("No content in tool call response"))
|
||||
}
|
||||
|
||||
/// List available resources
|
||||
pub async fn list_resources(&self) -> Result<Vec<McpResource>> {
|
||||
let response = self.send_request("resources/list", None).await?;
|
||||
|
||||
let resources = response
|
||||
.get("resources")
|
||||
.ok_or_else(|| eyre!("No resources in response"))?;
|
||||
|
||||
Ok(serde_json::from_value(resources.clone())?)
|
||||
}
|
||||
|
||||
/// Read a resource
|
||||
pub async fn read_resource(&self, uri: &str) -> Result<Value> {
|
||||
let params = serde_json::json!({
|
||||
"uri": uri
|
||||
});
|
||||
|
||||
let response = self.send_request("resources/read", Some(params)).await?;
|
||||
|
||||
response
|
||||
.get("contents")
|
||||
.cloned()
|
||||
.ok_or_else(|| eyre!("No contents in resource read response"))
|
||||
}
|
||||
|
||||
/// Get the server name
|
||||
pub fn server_name(&self) -> &str {
|
||||
&self.server_name
|
||||
}
|
||||
|
||||
/// Send a JSON-RPC request and get the response
|
||||
async fn send_request(&self, method: &str, params: Option<Value>) -> Result<Value> {
|
||||
let mut next_id = self.next_id.lock().await;
|
||||
let id = *next_id;
|
||||
*next_id += 1;
|
||||
drop(next_id);
|
||||
|
||||
let request = JsonRpcRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id,
|
||||
method: method.to_string(),
|
||||
params,
|
||||
};
|
||||
|
||||
let request_json = serde_json::to_string(&request)?;
|
||||
|
||||
let mut process = self.process.lock().await;
|
||||
|
||||
// Write request
|
||||
let stdin = process.stdin.as_mut().ok_or_else(|| eyre!("No stdin"))?;
|
||||
stdin.write_all(request_json.as_bytes()).await?;
|
||||
stdin.write_all(b"\n").await?;
|
||||
stdin.flush().await?;
|
||||
|
||||
// Read response
|
||||
let stdout = process.stdout.take().ok_or_else(|| eyre!("No stdout"))?;
|
||||
let mut reader = BufReader::new(stdout);
|
||||
let mut response_line = String::new();
|
||||
reader.read_line(&mut response_line).await?;
|
||||
|
||||
// Put stdout back
|
||||
process.stdout = Some(reader.into_inner());
|
||||
|
||||
drop(process);
|
||||
|
||||
let response: JsonRpcResponse = serde_json::from_str(&response_line)?;
|
||||
|
||||
if response.id != id {
|
||||
return Err(eyre!("Response ID mismatch: expected {}, got {}", id, response.id));
|
||||
}
|
||||
|
||||
if let Some(error) = response.error {
|
||||
return Err(eyre!("MCP error {}: {}", error.code, error.message));
|
||||
}
|
||||
|
||||
response.result.ok_or_else(|| eyre!("No result in response"))
|
||||
}
|
||||
|
||||
/// Close the MCP connection
|
||||
pub async fn close(self) -> Result<()> {
|
||||
let mut process = self.process.into_inner();
|
||||
|
||||
// Close stdin to signal the server to exit
|
||||
drop(process.stdin.take());
|
||||
|
||||
// Wait for process to exit (with timeout)
|
||||
tokio::time::timeout(
|
||||
std::time::Duration::from_secs(5),
|
||||
process.wait()
|
||||
).await??;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn jsonrpc_request_serializes() {
|
||||
let req = JsonRpcRequest {
|
||||
jsonrpc: "2.0".to_string(),
|
||||
id: 1,
|
||||
method: "test".to_string(),
|
||||
params: Some(serde_json::json!({"key": "value"})),
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(json.contains("\"method\":\"test\""));
|
||||
assert!(json.contains("\"id\":1"));
|
||||
}
|
||||
}
|
||||
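For orientation, a minimal consumer of this client could look like the sketch below. It is not part of the diff: the server command, its `--stdio` flag, and the `echo` tool are placeholders for whatever MCP server is actually configured, and the calling crate is assumed to depend on tokio (macros and a runtime), serde_json, and color-eyre itself.

use color_eyre::eyre::Result;
use mcp_client::McpClient;

#[tokio::main]
async fn main() -> Result<()> {
    // Spawn the server as a child process speaking JSON-RPC over stdio.
    // "some-mcp-server" and "--stdio" are placeholders, not real binaries or flags.
    let client = McpClient::spawn("some-mcp-server", &["--stdio"], "example-server").await?;

    // Capability negotiation must happen before any other call.
    let caps = client.initialize().await?;
    println!("{} advertises tools: {}", client.server_name(), caps.tools.is_some());

    // Discover tools, then invoke one by name (assuming the server exposes `echo`).
    for tool in client.list_tools().await? {
        println!("tool: {}", tool.name);
    }
    let content = client.call_tool("echo", serde_json::json!({"message": "hi"})).await?;
    println!("echo returned: {content}");

    // close() drops stdin to ask the server to exit and waits up to 5 seconds.
    client.close().await
}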
347
crates/integration/mcp-client/tests/mcp.rs
Normal file
@@ -0,0 +1,347 @@
|
||||
use mcp_client::McpClient;
|
||||
use std::fs;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[tokio::test]
|
||||
async fn mcp_server_capability_negotiation() {
|
||||
// Create a mock MCP server script
|
||||
let dir = tempdir().unwrap();
|
||||
let server_script = dir.path().join("mock_server.py");
|
||||
|
||||
let script_content = r#"#!/usr/bin/env python3
|
||||
import sys
|
||||
import json
|
||||
|
||||
def read_request():
|
||||
line = sys.stdin.readline()
|
||||
return json.loads(line)
|
||||
|
||||
def send_response(response):
|
||||
sys.stdout.write(json.dumps(response) + '\n')
|
||||
sys.stdout.flush()
|
||||
|
||||
# Main loop
|
||||
while True:
|
||||
try:
|
||||
req = read_request()
|
||||
method = req.get('method')
|
||||
req_id = req.get('id')
|
||||
|
||||
if method == 'initialize':
|
||||
send_response({
|
||||
'jsonrpc': '2.0',
|
||||
'id': req_id,
|
||||
'result': {
|
||||
'protocolVersion': '2024-11-05',
|
||||
'capabilities': {
|
||||
'tools': {'list_changed': True},
|
||||
'resources': {'subscribe': False}
|
||||
},
|
||||
'serverInfo': {
|
||||
'name': 'test-server',
|
||||
'version': '1.0.0'
|
||||
}
|
||||
}
|
||||
})
|
||||
elif method == 'tools/list':
|
||||
send_response({
|
||||
'jsonrpc': '2.0',
|
||||
'id': req_id,
|
||||
'result': {
|
||||
'tools': []
|
||||
}
|
||||
})
|
||||
else:
|
||||
send_response({
|
||||
'jsonrpc': '2.0',
|
||||
'id': req_id,
|
||||
'error': {
|
||||
'code': -32601,
|
||||
'message': f'Method not found: {method}'
|
||||
}
|
||||
})
|
||||
except EOFError:
|
||||
break
|
||||
except Exception as e:
|
||||
sys.stderr.write(f'Error: {e}\n')
|
||||
break
|
||||
"#;
|
||||
|
||||
fs::write(&server_script, script_content).unwrap();
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
fs::set_permissions(&server_script, std::fs::Permissions::from_mode(0o755)).unwrap();
|
||||
}
|
||||
|
||||
// Connect to the server
|
||||
let client = McpClient::spawn(
|
||||
"python3",
|
||||
&[server_script.to_str().unwrap()],
|
||||
"test-server"
|
||||
).await.unwrap();
|
||||
|
||||
// Initialize
|
||||
let capabilities = client.initialize().await.unwrap();
|
||||
|
||||
// Verify capabilities
|
||||
assert!(capabilities.tools.is_some());
|
||||
assert_eq!(capabilities.tools.unwrap().list_changed, Some(true));
|
||||
|
||||
client.close().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mcp_tool_invocation() {
|
||||
let dir = tempdir().unwrap();
|
||||
let server_script = dir.path().join("mock_server.py");
|
||||
|
||||
let script_content = r#"#!/usr/bin/env python3
|
||||
import sys
|
||||
import json
|
||||
|
||||
def read_request():
|
||||
line = sys.stdin.readline()
|
||||
return json.loads(line)
|
||||
|
||||
def send_response(response):
|
||||
sys.stdout.write(json.dumps(response) + '\n')
|
||||
sys.stdout.flush()
|
||||
|
||||
while True:
|
||||
try:
|
||||
req = read_request()
|
||||
method = req.get('method')
|
||||
req_id = req.get('id')
|
||||
params = req.get('params', {})
|
||||
|
||||
if method == 'initialize':
|
||||
send_response({
|
||||
'jsonrpc': '2.0',
|
||||
'id': req_id,
|
||||
'result': {
|
||||
'protocolVersion': '2024-11-05',
|
||||
'capabilities': {
|
||||
'tools': {}
|
||||
},
|
||||
'serverInfo': {
|
||||
'name': 'test-server',
|
||||
'version': '1.0.0'
|
||||
}
|
||||
}
|
||||
})
|
||||
elif method == 'tools/list':
|
||||
send_response({
|
||||
'jsonrpc': '2.0',
|
||||
'id': req_id,
|
||||
'result': {
|
||||
'tools': [
|
||||
{
|
||||
'name': 'echo',
|
||||
'description': 'Echo the input',
|
||||
'input_schema': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'message': {'type': 'string'}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
elif method == 'tools/call':
|
||||
tool_name = params.get('name')
|
||||
arguments = params.get('arguments', {})
|
||||
if tool_name == 'echo':
|
||||
send_response({
|
||||
'jsonrpc': '2.0',
|
||||
'id': req_id,
|
||||
'result': {
|
||||
'content': [
|
||||
{
|
||||
'type': 'text',
|
||||
'text': arguments.get('message', '')
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
else:
|
||||
send_response({
|
||||
'jsonrpc': '2.0',
|
||||
'id': req_id,
|
||||
'error': {
|
||||
'code': -32602,
|
||||
'message': f'Unknown tool: {tool_name}'
|
||||
}
|
||||
})
|
||||
else:
|
||||
send_response({
|
||||
'jsonrpc': '2.0',
|
||||
'id': req_id,
|
||||
'error': {
|
||||
'code': -32601,
|
||||
'message': f'Method not found: {method}'
|
||||
}
|
||||
})
|
||||
except EOFError:
|
||||
break
|
||||
except Exception as e:
|
||||
sys.stderr.write(f'Error: {e}\n')
|
||||
break
|
||||
"#;
|
||||
|
||||
fs::write(&server_script, script_content).unwrap();
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
fs::set_permissions(&server_script, std::fs::Permissions::from_mode(0o755)).unwrap();
|
||||
}
|
||||
|
||||
let client = McpClient::spawn(
|
||||
"python3",
|
||||
&[server_script.to_str().unwrap()],
|
||||
"test-server"
|
||||
).await.unwrap();
|
||||
|
||||
client.initialize().await.unwrap();
|
||||
|
||||
// List tools
|
||||
let tools = client.list_tools().await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].name, "echo");
|
||||
|
||||
// Call tool
|
||||
let result = client.call_tool(
|
||||
"echo",
|
||||
serde_json::json!({"message": "Hello, MCP!"})
|
||||
).await.unwrap();
|
||||
|
||||
// Verify result
|
||||
let content = result.as_array().unwrap();
|
||||
assert_eq!(content[0]["text"].as_str().unwrap(), "Hello, MCP!");
|
||||
|
||||
client.close().await.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn mcp_resource_reads() {
|
||||
let dir = tempdir().unwrap();
|
||||
let server_script = dir.path().join("mock_server.py");
|
||||
|
||||
let script_content = r#"#!/usr/bin/env python3
|
||||
import sys
|
||||
import json
|
||||
|
||||
def read_request():
|
||||
line = sys.stdin.readline()
|
||||
return json.loads(line)
|
||||
|
||||
def send_response(response):
|
||||
sys.stdout.write(json.dumps(response) + '\n')
|
||||
sys.stdout.flush()
|
||||
|
||||
while True:
|
||||
try:
|
||||
req = read_request()
|
||||
method = req.get('method')
|
||||
req_id = req.get('id')
|
||||
params = req.get('params', {})
|
||||
|
||||
if method == 'initialize':
|
||||
send_response({
|
||||
'jsonrpc': '2.0',
|
||||
'id': req_id,
|
||||
'result': {
|
||||
'protocolVersion': '2024-11-05',
|
||||
'capabilities': {
|
||||
'resources': {}
|
||||
},
|
||||
'serverInfo': {
|
||||
'name': 'test-server',
|
||||
'version': '1.0.0'
|
||||
}
|
||||
}
|
||||
})
|
||||
elif method == 'resources/list':
|
||||
send_response({
|
||||
'jsonrpc': '2.0',
|
||||
'id': req_id,
|
||||
'result': {
|
||||
'resources': [
|
||||
{
|
||||
'uri': 'file:///test.txt',
|
||||
'name': 'Test File',
|
||||
'description': 'A test file',
|
||||
'mime_type': 'text/plain'
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
elif method == 'resources/read':
|
||||
uri = params.get('uri')
|
||||
if uri == 'file:///test.txt':
|
||||
send_response({
|
||||
'jsonrpc': '2.0',
|
||||
'id': req_id,
|
||||
'result': {
|
||||
'contents': [
|
||||
{
|
||||
'uri': uri,
|
||||
'mime_type': 'text/plain',
|
||||
'text': 'Hello from resource!'
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
else:
|
||||
send_response({
|
||||
'jsonrpc': '2.0',
|
||||
'id': req_id,
|
||||
'error': {
|
||||
'code': -32602,
|
||||
'message': f'Unknown resource: {uri}'
|
||||
}
|
||||
})
|
||||
else:
|
||||
send_response({
|
||||
'jsonrpc': '2.0',
|
||||
'id': req_id,
|
||||
'error': {
|
||||
'code': -32601,
|
||||
'message': f'Method not found: {method}'
|
||||
}
|
||||
})
|
||||
except EOFError:
|
||||
break
|
||||
except Exception as e:
|
||||
sys.stderr.write(f'Error: {e}\n')
|
||||
break
|
||||
"#;
|
||||
|
||||
fs::write(&server_script, script_content).unwrap();
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
fs::set_permissions(&server_script, std::fs::Permissions::from_mode(0o755)).unwrap();
|
||||
}
|
||||
|
||||
let client = McpClient::spawn(
|
||||
"python3",
|
||||
&[server_script.to_str().unwrap()],
|
||||
"test-server"
|
||||
).await.unwrap();
|
||||
|
||||
client.initialize().await.unwrap();
|
||||
|
||||
// List resources
|
||||
let resources = client.list_resources().await.unwrap();
|
||||
assert_eq!(resources.len(), 1);
|
||||
assert_eq!(resources[0].uri, "file:///test.txt");
|
||||
|
||||
// Read resource
|
||||
let contents = client.read_resource("file:///test.txt").await.unwrap();
|
||||
let contents_array = contents.as_array().unwrap();
|
||||
assert_eq!(contents_array[0]["text"].as_str().unwrap(), "Hello from resource!");
|
||||
|
||||
client.close().await.unwrap();
|
||||
}
|
||||
18
crates/llm/anthropic/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
[package]
name = "llm-anthropic"
version = "0.1.0"
edition.workspace = true
license.workspace = true
description = "Anthropic Claude API client for Owlen"

[dependencies]
llm-core = { path = "../core" }
async-trait = "0.1"
futures = "0.3"
reqwest = { version = "0.12", features = ["json", "stream"] }
reqwest-eventsource = "0.6"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1", features = ["sync", "time"] }
tracing = "0.1"
uuid = { version = "1.0", features = ["v4"] }
285
crates/llm/anthropic/src/auth.rs
Normal file
@@ -0,0 +1,285 @@
|
||||
//! Anthropic OAuth Authentication
|
||||
//!
|
||||
//! Implements device code flow for authenticating with Anthropic without API keys.
|
||||
|
||||
use llm_core::{AuthMethod, DeviceAuthResult, DeviceCodeResponse, LlmError, OAuthProvider};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// OAuth client for Anthropic device flow
|
||||
pub struct AnthropicAuth {
|
||||
http: Client,
|
||||
client_id: String,
|
||||
}
|
||||
|
||||
// Anthropic OAuth endpoints (these would be the real endpoints)
|
||||
const AUTH_BASE_URL: &str = "https://console.anthropic.com";
|
||||
const DEVICE_CODE_ENDPOINT: &str = "/oauth/device/code";
|
||||
const TOKEN_ENDPOINT: &str = "/oauth/token";
|
||||
|
||||
// Default client ID for Owlen CLI
|
||||
const DEFAULT_CLIENT_ID: &str = "owlen-cli";
|
||||
|
||||
impl AnthropicAuth {
|
||||
/// Create a new OAuth client with the default CLI client ID
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
client_id: DEFAULT_CLIENT_ID.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with a custom client ID
|
||||
pub fn with_client_id(client_id: impl Into<String>) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
client_id: client_id.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AnthropicAuth {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct DeviceCodeRequest<'a> {
|
||||
client_id: &'a str,
|
||||
scope: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeviceCodeApiResponse {
|
||||
device_code: String,
|
||||
user_code: String,
|
||||
verification_uri: String,
|
||||
verification_uri_complete: Option<String>,
|
||||
expires_in: u64,
|
||||
interval: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct TokenRequest<'a> {
|
||||
client_id: &'a str,
|
||||
device_code: &'a str,
|
||||
grant_type: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenApiResponse {
|
||||
access_token: String,
|
||||
#[allow(dead_code)]
|
||||
token_type: String,
|
||||
expires_in: Option<u64>,
|
||||
refresh_token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenErrorResponse {
|
||||
error: String,
|
||||
error_description: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OAuthProvider for AnthropicAuth {
|
||||
async fn start_device_auth(&self) -> Result<DeviceCodeResponse, LlmError> {
|
||||
let url = format!("{}{}", AUTH_BASE_URL, DEVICE_CODE_ENDPOINT);
|
||||
|
||||
let request = DeviceCodeRequest {
|
||||
client_id: &self.client_id,
|
||||
scope: "api:read api:write", // Request API access
|
||||
};
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(LlmError::Auth(format!(
|
||||
"Device code request failed ({}): {}",
|
||||
status, text
|
||||
)));
|
||||
}
|
||||
|
||||
let api_response: DeviceCodeApiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
Ok(DeviceCodeResponse {
|
||||
device_code: api_response.device_code,
|
||||
user_code: api_response.user_code,
|
||||
verification_uri: api_response.verification_uri,
|
||||
verification_uri_complete: api_response.verification_uri_complete,
|
||||
expires_in: api_response.expires_in,
|
||||
interval: api_response.interval,
|
||||
})
|
||||
}
|
||||
|
||||
async fn poll_device_auth(&self, device_code: &str) -> Result<DeviceAuthResult, LlmError> {
|
||||
let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);
|
||||
|
||||
let request = TokenRequest {
|
||||
client_id: &self.client_id,
|
||||
device_code,
|
||||
grant_type: "urn:ietf:params:oauth:grant-type:device_code",
|
||||
};
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let token_response: TokenApiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
return Ok(DeviceAuthResult::Success {
|
||||
access_token: token_response.access_token,
|
||||
refresh_token: token_response.refresh_token,
|
||||
expires_in: token_response.expires_in,
|
||||
});
|
||||
}
|
||||
|
||||
// Parse error response
|
||||
let error_response: TokenErrorResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
match error_response.error.as_str() {
|
||||
"authorization_pending" => Ok(DeviceAuthResult::Pending),
|
||||
"slow_down" => Ok(DeviceAuthResult::Pending), // Treat as pending, caller should slow down
|
||||
"access_denied" => Ok(DeviceAuthResult::Denied),
|
||||
"expired_token" => Ok(DeviceAuthResult::Expired),
|
||||
_ => Err(LlmError::Auth(format!(
|
||||
"Token request failed: {} - {}",
|
||||
error_response.error,
|
||||
error_response.error_description.unwrap_or_default()
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
async fn refresh_token(&self, refresh_token: &str) -> Result<AuthMethod, LlmError> {
|
||||
let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct RefreshRequest<'a> {
|
||||
client_id: &'a str,
|
||||
refresh_token: &'a str,
|
||||
grant_type: &'a str,
|
||||
}
|
||||
|
||||
let request = RefreshRequest {
|
||||
client_id: &self.client_id,
|
||||
refresh_token,
|
||||
grant_type: "refresh_token",
|
||||
};
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(LlmError::Auth(format!("Token refresh failed: {}", text)));
|
||||
}
|
||||
|
||||
let token_response: TokenApiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
let expires_at = token_response.expires_in.map(|secs| {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs() + secs)
|
||||
.unwrap_or(0)
|
||||
});
|
||||
|
||||
Ok(AuthMethod::OAuth {
|
||||
access_token: token_response.access_token,
|
||||
refresh_token: token_response.refresh_token,
|
||||
expires_at,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to perform the full device auth flow with polling
|
||||
pub async fn perform_device_auth<F>(
|
||||
auth: &AnthropicAuth,
|
||||
on_code: F,
|
||||
) -> Result<AuthMethod, LlmError>
|
||||
where
|
||||
F: FnOnce(&DeviceCodeResponse),
|
||||
{
|
||||
// Start the device flow
|
||||
let device_code = auth.start_device_auth().await?;
|
||||
|
||||
// Let caller display the code to user
|
||||
on_code(&device_code);
|
||||
|
||||
// Poll for completion
|
||||
let poll_interval = std::time::Duration::from_secs(device_code.interval);
|
||||
let deadline =
|
||||
std::time::Instant::now() + std::time::Duration::from_secs(device_code.expires_in);
|
||||
|
||||
loop {
|
||||
if std::time::Instant::now() > deadline {
|
||||
return Err(LlmError::Auth("Device code expired".to_string()));
|
||||
}
|
||||
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
|
||||
match auth.poll_device_auth(&device_code.device_code).await? {
|
||||
DeviceAuthResult::Success {
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_in,
|
||||
} => {
|
||||
let expires_at = expires_in.map(|secs| {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs() + secs)
|
||||
.unwrap_or(0)
|
||||
});
|
||||
|
||||
return Ok(AuthMethod::OAuth {
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_at,
|
||||
});
|
||||
}
|
||||
DeviceAuthResult::Pending => continue,
|
||||
DeviceAuthResult::Denied => {
|
||||
return Err(LlmError::Auth("Authorization denied by user".to_string()));
|
||||
}
|
||||
DeviceAuthResult::Expired => {
|
||||
return Err(LlmError::Auth("Device code expired".to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
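A possible way to drive `perform_device_auth` from a caller is sketched below; how the returned `AuthMethod` is persisted, and the exact shape of the CLI around it, are outside this diff and shown only as an assumption.

use llm_anthropic::{perform_device_auth, AnthropicAuth};
use llm_core::{AuthMethod, LlmError};

async fn login() -> Result<AuthMethod, LlmError> {
    let auth = AnthropicAuth::new();
    // The closure runs once with the device code so the user can be told
    // where to go and what to enter; polling then continues until success,
    // denial, or expiry.
    perform_device_auth(&auth, |code| {
        println!("Open {} and enter code {}", code.verification_uri, code.user_code);
        if let Some(url) = &code.verification_uri_complete {
            println!("Or open {url} directly");
        }
    })
    .await
}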
577
crates/llm/anthropic/src/client.rs
Normal file
@@ -0,0 +1,577 @@
|
||||
//! Anthropic Claude API Client
|
||||
//!
|
||||
//! Implements the Messages API with streaming support.
|
||||
|
||||
use crate::types::*;
|
||||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use llm_core::{
|
||||
AccountInfo, AuthMethod, ChatMessage, ChatOptions, ChatResponse, ChunkStream, FunctionCall,
|
||||
LlmError, LlmProvider, ModelInfo, ProviderInfo, ProviderStatus, Role, StreamChunk, Tool,
|
||||
ToolCall, ToolCallDelta, Usage, UsageStats,
|
||||
};
|
||||
use reqwest::Client;
|
||||
use reqwest_eventsource::{Event, EventSource};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
const API_BASE_URL: &str = "https://api.anthropic.com";
|
||||
const MESSAGES_ENDPOINT: &str = "/v1/messages";
|
||||
const API_VERSION: &str = "2023-06-01";
|
||||
const DEFAULT_MAX_TOKENS: u32 = 8192;
|
||||
|
||||
/// Anthropic Claude API client
|
||||
pub struct AnthropicClient {
|
||||
http: Client,
|
||||
auth: AuthMethod,
|
||||
model: String,
|
||||
}
|
||||
|
||||
impl AnthropicClient {
|
||||
/// Create a new client with API key authentication
|
||||
pub fn new(api_key: impl Into<String>) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
auth: AuthMethod::api_key(api_key),
|
||||
model: "claude-sonnet-4-20250514".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new client with OAuth token
|
||||
pub fn with_oauth(access_token: impl Into<String>) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
auth: AuthMethod::oauth(access_token),
|
||||
model: "claude-sonnet-4-20250514".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new client with full AuthMethod
|
||||
pub fn with_auth(auth: AuthMethod) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
auth,
|
||||
model: "claude-sonnet-4-20250514".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the model to use
|
||||
pub fn with_model(mut self, model: impl Into<String>) -> Self {
|
||||
self.model = model.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Get current auth method (for token refresh)
|
||||
pub fn auth(&self) -> &AuthMethod {
|
||||
&self.auth
|
||||
}
|
||||
|
||||
/// Update the auth method (after refresh)
|
||||
pub fn set_auth(&mut self, auth: AuthMethod) {
|
||||
self.auth = auth;
|
||||
}
|
||||
|
||||
/// Convert messages to Anthropic format, extracting system message
|
||||
fn prepare_messages(messages: &[ChatMessage]) -> (Option<String>, Vec<AnthropicMessage>) {
|
||||
let mut system_content = None;
|
||||
let mut anthropic_messages = Vec::new();
|
||||
|
||||
for msg in messages {
|
||||
if msg.role == Role::System {
|
||||
// Collect system messages
|
||||
if let Some(content) = &msg.content {
|
||||
if let Some(existing) = &mut system_content {
|
||||
*existing = format!("{}\n\n{}", existing, content);
|
||||
} else {
|
||||
system_content = Some(content.clone());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
anthropic_messages.push(AnthropicMessage::from(msg));
|
||||
}
|
||||
}
|
||||
|
||||
(system_content, anthropic_messages)
|
||||
}
|
||||
|
||||
/// Convert tools to Anthropic format
|
||||
fn prepare_tools(tools: Option<&[Tool]>) -> Option<Vec<AnthropicTool>> {
|
||||
tools.map(|t| t.iter().map(AnthropicTool::from).collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for AnthropicClient {
|
||||
fn name(&self) -> &str {
|
||||
"anthropic"
|
||||
}
|
||||
|
||||
fn model(&self) -> &str {
|
||||
&self.model
|
||||
}
|
||||
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
options: &ChatOptions,
|
||||
tools: Option<&[Tool]>,
|
||||
) -> Result<ChunkStream, LlmError> {
|
||||
let url = format!("{}{}", API_BASE_URL, MESSAGES_ENDPOINT);
|
||||
|
||||
let model = if options.model.is_empty() {
|
||||
&self.model
|
||||
} else {
|
||||
&options.model
|
||||
};
|
||||
|
||||
let (system, anthropic_messages) = Self::prepare_messages(messages);
|
||||
let anthropic_tools = Self::prepare_tools(tools);
|
||||
|
||||
let request = MessagesRequest {
|
||||
model,
|
||||
messages: anthropic_messages,
|
||||
max_tokens: options.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
|
||||
system: system.as_deref(),
|
||||
temperature: options.temperature,
|
||||
top_p: options.top_p,
|
||||
stop_sequences: options.stop.as_deref(),
|
||||
tools: anthropic_tools,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let bearer = self
|
||||
.auth
|
||||
.bearer_token()
|
||||
.ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;
|
||||
|
||||
// Build the SSE request
|
||||
let req = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("x-api-key", bearer)
|
||||
.header("anthropic-version", API_VERSION)
|
||||
.header("content-type", "application/json")
|
||||
.json(&request);
|
||||
|
||||
let es = EventSource::new(req).map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
// State for accumulating tool calls across deltas
|
||||
let tool_state: Arc<Mutex<Vec<PartialToolCall>>> = Arc::new(Mutex::new(Vec::new()));
|
||||
|
||||
let stream = es.filter_map(move |event| {
|
||||
let tool_state = Arc::clone(&tool_state);
|
||||
async move {
|
||||
match event {
|
||||
Ok(Event::Open) => None,
|
||||
Ok(Event::Message(msg)) => {
|
||||
// Parse the SSE data as JSON
|
||||
let event: StreamEvent = match serde_json::from_str(&msg.data) {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse SSE event: {}", e);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
convert_stream_event(event, &tool_state).await
|
||||
}
|
||||
Err(reqwest_eventsource::Error::StreamEnded) => None,
|
||||
Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
options: &ChatOptions,
|
||||
tools: Option<&[Tool]>,
|
||||
) -> Result<ChatResponse, LlmError> {
|
||||
let url = format!("{}{}", API_BASE_URL, MESSAGES_ENDPOINT);
|
||||
|
||||
let model = if options.model.is_empty() {
|
||||
&self.model
|
||||
} else {
|
||||
&options.model
|
||||
};
|
||||
|
||||
let (system, anthropic_messages) = Self::prepare_messages(messages);
|
||||
let anthropic_tools = Self::prepare_tools(tools);
|
||||
|
||||
let request = MessagesRequest {
|
||||
model,
|
||||
messages: anthropic_messages,
|
||||
max_tokens: options.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
|
||||
system: system.as_deref(),
|
||||
temperature: options.temperature,
|
||||
top_p: options.top_p,
|
||||
stop_sequences: options.stop.as_deref(),
|
||||
tools: anthropic_tools,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let bearer = self
|
||||
.auth
|
||||
.bearer_token()
|
||||
.ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("x-api-key", bearer)
|
||||
.header("anthropic-version", API_VERSION)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
// Check for rate limiting
|
||||
if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
|
||||
return Err(LlmError::RateLimit {
|
||||
retry_after_secs: None,
|
||||
});
|
||||
}
|
||||
|
||||
return Err(LlmError::Api {
|
||||
message: text,
|
||||
code: Some(status.to_string()),
|
||||
});
|
||||
}
|
||||
|
||||
let api_response: MessagesResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
// Convert response to common format
|
||||
let mut content = String::new();
|
||||
let mut tool_calls = Vec::new();
|
||||
|
||||
for block in api_response.content {
|
||||
match block {
|
||||
ResponseContentBlock::Text { text } => {
|
||||
content.push_str(&text);
|
||||
}
|
||||
ResponseContentBlock::ToolUse { id, name, input } => {
|
||||
tool_calls.push(ToolCall {
|
||||
id,
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name,
|
||||
arguments: input,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let usage = api_response.usage.map(|u| Usage {
|
||||
prompt_tokens: u.input_tokens,
|
||||
completion_tokens: u.output_tokens,
|
||||
total_tokens: u.input_tokens + u.output_tokens,
|
||||
});
|
||||
|
||||
Ok(ChatResponse {
|
||||
content: if content.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(content)
|
||||
},
|
||||
tool_calls: if tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(tool_calls)
|
||||
},
|
||||
usage,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper struct for accumulating streaming tool calls
|
||||
#[derive(Default)]
|
||||
struct PartialToolCall {
|
||||
#[allow(dead_code)]
|
||||
id: String,
|
||||
#[allow(dead_code)]
|
||||
name: String,
|
||||
input_json: String,
|
||||
}
|
||||
|
||||
/// Convert an Anthropic stream event to our common StreamChunk format
|
||||
async fn convert_stream_event(
|
||||
event: StreamEvent,
|
||||
tool_state: &Arc<Mutex<Vec<PartialToolCall>>>,
|
||||
) -> Option<Result<StreamChunk, LlmError>> {
|
||||
match event {
|
||||
StreamEvent::ContentBlockStart {
|
||||
index,
|
||||
content_block,
|
||||
} => {
|
||||
match content_block {
|
||||
ContentBlockStartInfo::Text { text } => {
|
||||
if text.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(Ok(StreamChunk {
|
||||
content: Some(text),
|
||||
tool_calls: None,
|
||||
done: false,
|
||||
usage: None,
|
||||
}))
|
||||
}
|
||||
}
|
||||
ContentBlockStartInfo::ToolUse { id, name } => {
|
||||
// Store the tool call start
|
||||
let mut state = tool_state.lock().await;
|
||||
while state.len() <= index {
|
||||
state.push(PartialToolCall::default());
|
||||
}
|
||||
state[index] = PartialToolCall {
|
||||
id: id.clone(),
|
||||
name: name.clone(),
|
||||
input_json: String::new(),
|
||||
};
|
||||
|
||||
Some(Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index,
|
||||
id: Some(id),
|
||||
function_name: Some(name),
|
||||
arguments_delta: None,
|
||||
}]),
|
||||
done: false,
|
||||
usage: None,
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StreamEvent::ContentBlockDelta { index, delta } => match delta {
|
||||
ContentDelta::TextDelta { text } => Some(Ok(StreamChunk {
|
||||
content: Some(text),
|
||||
tool_calls: None,
|
||||
done: false,
|
||||
usage: None,
|
||||
})),
|
||||
ContentDelta::InputJsonDelta { partial_json } => {
|
||||
// Accumulate the JSON
|
||||
let mut state = tool_state.lock().await;
|
||||
if index < state.len() {
|
||||
state[index].input_json.push_str(&partial_json);
|
||||
}
|
||||
|
||||
Some(Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index,
|
||||
id: None,
|
||||
function_name: None,
|
||||
arguments_delta: Some(partial_json),
|
||||
}]),
|
||||
done: false,
|
||||
usage: None,
|
||||
}))
|
||||
}
|
||||
},
|
||||
|
||||
StreamEvent::MessageDelta { usage, .. } => {
|
||||
let u = usage.map(|u| Usage {
|
||||
prompt_tokens: u.input_tokens,
|
||||
completion_tokens: u.output_tokens,
|
||||
total_tokens: u.input_tokens + u.output_tokens,
|
||||
});
|
||||
|
||||
Some(Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
done: false,
|
||||
usage: u,
|
||||
}))
|
||||
}
|
||||
|
||||
StreamEvent::MessageStop => Some(Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
done: true,
|
||||
usage: None,
|
||||
})),
|
||||
|
||||
StreamEvent::Error { error } => Some(Err(LlmError::Api {
|
||||
message: error.message,
|
||||
code: Some(error.error_type),
|
||||
})),
|
||||
|
||||
// Ignore other events
|
||||
StreamEvent::MessageStart { .. }
|
||||
| StreamEvent::ContentBlockStop { .. }
|
||||
| StreamEvent::Ping => None,
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ProviderInfo Implementation
|
||||
// ============================================================================
|
||||
|
||||
/// Known Claude models with their specifications
|
||||
fn get_claude_models() -> Vec<ModelInfo> {
|
||||
vec![
|
||||
ModelInfo {
|
||||
id: "claude-opus-4-20250514".to_string(),
|
||||
display_name: Some("Claude Opus 4".to_string()),
|
||||
description: Some("Most capable model for complex tasks".to_string()),
|
||||
context_window: Some(200_000),
|
||||
max_output_tokens: Some(32_000),
|
||||
supports_tools: true,
|
||||
supports_vision: true,
|
||||
input_price_per_mtok: Some(15.0),
|
||||
output_price_per_mtok: Some(75.0),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "claude-sonnet-4-20250514".to_string(),
|
||||
display_name: Some("Claude Sonnet 4".to_string()),
|
||||
description: Some("Best balance of performance and speed".to_string()),
|
||||
context_window: Some(200_000),
|
||||
max_output_tokens: Some(64_000),
|
||||
supports_tools: true,
|
||||
supports_vision: true,
|
||||
input_price_per_mtok: Some(3.0),
|
||||
output_price_per_mtok: Some(15.0),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "claude-haiku-3-5-20241022".to_string(),
|
||||
display_name: Some("Claude 3.5 Haiku".to_string()),
|
||||
description: Some("Fast and affordable for simple tasks".to_string()),
|
||||
context_window: Some(200_000),
|
||||
max_output_tokens: Some(8_192),
|
||||
supports_tools: true,
|
||||
supports_vision: true,
|
||||
input_price_per_mtok: Some(0.80),
|
||||
output_price_per_mtok: Some(4.0),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ProviderInfo for AnthropicClient {
|
||||
async fn status(&self) -> Result<ProviderStatus, LlmError> {
|
||||
let authenticated = self.auth.bearer_token().is_some();
|
||||
|
||||
// Try to reach the API with a simple request
|
||||
let reachable = if authenticated {
|
||||
// Test with a minimal message to verify auth works
|
||||
let test_messages = vec![ChatMessage::user("Hi")];
|
||||
let test_opts = ChatOptions::new(&self.model).with_max_tokens(1);
|
||||
|
||||
match self.chat(&test_messages, &test_opts, None).await {
|
||||
Ok(_) => true,
|
||||
Err(LlmError::Auth(_)) => false, // Auth failed
|
||||
Err(_) => true, // Other errors mean API is reachable
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
let account = if authenticated && reachable {
|
||||
self.account_info().await.ok().flatten()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let message = if !authenticated {
|
||||
Some("Not authenticated - run 'owlen login anthropic' to authenticate".to_string())
|
||||
} else if !reachable {
|
||||
Some("Cannot reach Anthropic API".to_string())
|
||||
} else {
|
||||
Some("Connected".to_string())
|
||||
};
|
||||
|
||||
Ok(ProviderStatus {
|
||||
provider: "anthropic".to_string(),
|
||||
authenticated,
|
||||
account,
|
||||
model: self.model.clone(),
|
||||
endpoint: API_BASE_URL.to_string(),
|
||||
reachable,
|
||||
message,
|
||||
})
|
||||
}
|
||||
|
||||
async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError> {
|
||||
// Anthropic doesn't have a public account info endpoint
|
||||
// Return None - account info would come from OAuth token claims
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError> {
|
||||
// Anthropic doesn't expose usage stats via API
|
||||
// This would require the admin/billing API with different auth
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
|
||||
// Return known models - Anthropic doesn't have a models list endpoint
|
||||
Ok(get_claude_models())
|
||||
}
|
||||
|
||||
async fn model_info(&self, model_id: &str) -> Result<Option<ModelInfo>, LlmError> {
|
||||
let models = get_claude_models();
|
||||
Ok(models.into_iter().find(|m| m.id == model_id))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use llm_core::ToolParameters;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_message_conversion() {
|
||||
let messages = vec![
|
||||
ChatMessage::system("You are helpful"),
|
||||
ChatMessage::user("Hello"),
|
||||
ChatMessage::assistant("Hi there!"),
|
||||
];
|
||||
|
||||
let (system, anthropic_msgs) = AnthropicClient::prepare_messages(&messages);
|
||||
|
||||
assert_eq!(system, Some("You are helpful".to_string()));
|
||||
assert_eq!(anthropic_msgs.len(), 2);
|
||||
assert_eq!(anthropic_msgs[0].role, "user");
|
||||
assert_eq!(anthropic_msgs[1].role, "assistant");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_conversion() {
|
||||
let tools = vec![Tool::function(
|
||||
"read_file",
|
||||
"Read a file's contents",
|
||||
ToolParameters::object(
|
||||
json!({
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path"
|
||||
}
|
||||
}),
|
||||
vec!["path".to_string()],
|
||||
),
|
||||
)];
|
||||
|
||||
let anthropic_tools = AnthropicClient::prepare_tools(Some(&tools)).unwrap();
|
||||
|
||||
assert_eq!(anthropic_tools.len(), 1);
|
||||
assert_eq!(anthropic_tools[0].name, "read_file");
|
||||
assert_eq!(anthropic_tools[0].description, "Read a file's contents");
|
||||
}
|
||||
}
|
||||
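For context, a streaming call against this client might look like the sketch below. It assumes `ChunkStream` from llm-core is a futures `Stream` of `StreamChunk` results (as the code above treats it) and that the caller supplies its own API key; none of this is prescribed by the diff.

use futures::StreamExt;
use llm_anthropic::AnthropicClient;
use llm_core::{ChatMessage, ChatOptions, LlmError, LlmProvider};

async fn stream_reply(api_key: &str) -> Result<(), LlmError> {
    let client = AnthropicClient::new(api_key).with_model("claude-sonnet-4-20250514");

    let messages = vec![ChatMessage::user("Say hello in one sentence.")];
    let options = ChatOptions::new("claude-sonnet-4-20250514");

    // chat_stream yields StreamChunk values until one arrives with done == true.
    let mut stream = client.chat_stream(&messages, &options, None).await?;
    while let Some(chunk) = stream.next().await {
        let chunk = chunk?;
        if let Some(text) = chunk.content {
            print!("{text}");
        }
        if chunk.done {
            break;
        }
    }
    Ok(())
}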
12
crates/llm/anthropic/src/lib.rs
Normal file
@@ -0,0 +1,12 @@
//! Anthropic Claude API Client
//!
//! Implements the LlmProvider trait for Anthropic's Claude models.
//! Supports both API key authentication and OAuth device flow.

mod auth;
mod client;
mod types;

pub use auth::*;
pub use client::*;
pub use types::*;
276
crates/llm/anthropic/src/types.rs
Normal file
@@ -0,0 +1,276 @@
|
||||
//! Anthropic API request/response types
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
// ============================================================================
|
||||
// Request Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct MessagesRequest<'a> {
|
||||
pub model: &'a str,
|
||||
pub messages: Vec<AnthropicMessage>,
|
||||
pub max_tokens: u32,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system: Option<&'a str>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop_sequences: Option<&'a [String]>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<AnthropicTool>>,
|
||||
|
||||
pub stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AnthropicMessage {
|
||||
pub role: String, // "user" or "assistant"
|
||||
pub content: AnthropicContent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum AnthropicContent {
|
||||
Text(String),
|
||||
Blocks(Vec<ContentBlock>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ContentBlock {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
|
||||
#[serde(rename = "tool_use")]
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: Value,
|
||||
},
|
||||
|
||||
#[serde(rename = "tool_result")]
|
||||
ToolResult {
|
||||
tool_use_id: String,
|
||||
content: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
is_error: Option<bool>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AnthropicTool {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub input_schema: ToolInputSchema,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolInputSchema {
|
||||
#[serde(rename = "type")]
|
||||
pub schema_type: String,
|
||||
pub properties: Value,
|
||||
pub required: Vec<String>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Response Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct MessagesResponse {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub response_type: String,
|
||||
pub role: String,
|
||||
pub content: Vec<ResponseContentBlock>,
|
||||
pub model: String,
|
||||
pub stop_reason: Option<String>,
|
||||
pub usage: Option<UsageInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ResponseContentBlock {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
|
||||
#[serde(rename = "tool_use")]
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: Value,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct UsageInfo {
|
||||
pub input_tokens: u32,
|
||||
pub output_tokens: u32,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Streaming Event Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum StreamEvent {
|
||||
#[serde(rename = "message_start")]
|
||||
MessageStart { message: MessageStartInfo },
|
||||
|
||||
#[serde(rename = "content_block_start")]
|
||||
ContentBlockStart {
|
||||
index: usize,
|
||||
content_block: ContentBlockStartInfo,
|
||||
},
|
||||
|
||||
#[serde(rename = "content_block_delta")]
|
||||
ContentBlockDelta { index: usize, delta: ContentDelta },
|
||||
|
||||
#[serde(rename = "content_block_stop")]
|
||||
ContentBlockStop { index: usize },
|
||||
|
||||
#[serde(rename = "message_delta")]
|
||||
MessageDelta {
|
||||
delta: MessageDeltaInfo,
|
||||
usage: Option<UsageInfo>,
|
||||
},
|
||||
|
||||
#[serde(rename = "message_stop")]
|
||||
MessageStop,
|
||||
|
||||
#[serde(rename = "ping")]
|
||||
Ping,
|
||||
|
||||
#[serde(rename = "error")]
|
||||
Error { error: ApiError },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct MessageStartInfo {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub message_type: String,
|
||||
pub role: String,
|
||||
pub model: String,
|
||||
pub usage: Option<UsageInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ContentBlockStartInfo {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
|
||||
#[serde(rename = "tool_use")]
|
||||
ToolUse { id: String, name: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ContentDelta {
|
||||
#[serde(rename = "text_delta")]
|
||||
TextDelta { text: String },
|
||||
|
||||
#[serde(rename = "input_json_delta")]
|
||||
InputJsonDelta { partial_json: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct MessageDeltaInfo {
|
||||
pub stop_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ApiError {
|
||||
#[serde(rename = "type")]
|
||||
pub error_type: String,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Conversions
|
||||
// ============================================================================
|
||||
|
||||
impl From<&llm_core::Tool> for AnthropicTool {
|
||||
fn from(tool: &llm_core::Tool) -> Self {
|
||||
Self {
|
||||
name: tool.function.name.clone(),
|
||||
description: tool.function.description.clone(),
|
||||
input_schema: ToolInputSchema {
|
||||
schema_type: tool.function.parameters.param_type.clone(),
|
||||
properties: tool.function.parameters.properties.clone(),
|
||||
required: tool.function.parameters.required.clone(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&llm_core::ChatMessage> for AnthropicMessage {
|
||||
fn from(msg: &llm_core::ChatMessage) -> Self {
|
||||
use llm_core::Role;
|
||||
|
||||
let role = match msg.role {
|
||||
Role::User | Role::System => "user",
|
||||
Role::Assistant => "assistant",
|
||||
Role::Tool => "user", // Tool results come as user messages in Anthropic
|
||||
};
|
||||
|
||||
// Handle tool results
|
||||
if msg.role == Role::Tool {
|
||||
if let (Some(tool_call_id), Some(content)) = (&msg.tool_call_id, &msg.content) {
|
||||
return Self {
|
||||
role: "user".to_string(),
|
||||
content: AnthropicContent::Blocks(vec![ContentBlock::ToolResult {
|
||||
tool_use_id: tool_call_id.clone(),
|
||||
content: content.clone(),
|
||||
is_error: None,
|
||||
}]),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Handle assistant messages with tool calls
|
||||
if msg.role == Role::Assistant {
|
||||
if let Some(tool_calls) = &msg.tool_calls {
|
||||
let mut blocks: Vec<ContentBlock> = Vec::new();
|
||||
|
||||
// Add text content if present
|
||||
if let Some(text) = &msg.content {
|
||||
if !text.is_empty() {
|
||||
blocks.push(ContentBlock::Text { text: text.clone() });
|
||||
}
|
||||
}
|
||||
|
||||
// Add tool use blocks
|
||||
for call in tool_calls {
|
||||
blocks.push(ContentBlock::ToolUse {
|
||||
id: call.id.clone(),
|
||||
name: call.function.name.clone(),
|
||||
input: call.function.arguments.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
return Self {
|
||||
role: "assistant".to_string(),
|
||||
content: AnthropicContent::Blocks(blocks),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Simple text message
|
||||
Self {
|
||||
role: role.to_string(),
|
||||
content: AnthropicContent::Text(msg.content.clone().unwrap_or_default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
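As an illustrative check (not part of the diff), the tagged and untagged serde attributes above produce the nested JSON shape the Messages API expects for tool results; the `toolu_123` id below is a made-up placeholder.

use llm_anthropic::{AnthropicContent, AnthropicMessage, ContentBlock};

#[test]
fn tool_result_serializes_to_tagged_block() {
    let msg = AnthropicMessage {
        role: "user".to_string(),
        content: AnthropicContent::Blocks(vec![ContentBlock::ToolResult {
            tool_use_id: "toolu_123".to_string(),
            content: "file contents here".to_string(),
            is_error: None,
        }]),
    };

    let json = serde_json::to_value(&msg).unwrap();
    assert_eq!(json["role"], "user");
    // The untagged content enum flattens to an array of type-tagged blocks.
    assert_eq!(json["content"][0]["type"], "tool_result");
    assert_eq!(json["content"][0]["tool_use_id"], "toolu_123");
}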
18
crates/llm/core/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
[package]
name = "llm-core"
version = "0.1.0"
edition.workspace = true
license.workspace = true
description = "LLM provider abstraction layer for Owlen"

[dependencies]
async-trait = "0.1"
futures = "0.3"
rand = "0.8"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "2.0"
tokio = { version = "1.0", features = ["time"] }

[dev-dependencies]
tokio = { version = "1.0", features = ["macros", "rt"] }
195
crates/llm/core/examples/token_counting.rs
Normal file
@@ -0,0 +1,195 @@
|
||||
//! Token counting example
|
||||
//!
|
||||
//! This example demonstrates how to use the token counting utilities
|
||||
//! to manage LLM context windows.
|
||||
//!
|
||||
//! Run with: cargo run --example token_counting -p llm-core
|
||||
|
||||
use llm_core::{
|
||||
ChatMessage, ClaudeTokenCounter, ContextWindow, SimpleTokenCounter, TokenCounter,
|
||||
};
|
||||
|
||||
fn main() {
|
||||
println!("=== Token Counting Example ===\n");
|
||||
|
||||
// Example 1: Basic token counting with SimpleTokenCounter
|
||||
println!("1. Basic Token Counting");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
let simple_counter = SimpleTokenCounter::new(8192);
|
||||
let text = "The quick brown fox jumps over the lazy dog.";
|
||||
|
||||
let token_count = simple_counter.count(text);
|
||||
println!("Text: \"{}\"", text);
|
||||
println!("Estimated tokens: {}", token_count);
|
||||
println!("Max context: {}\n", simple_counter.max_context());
|
||||
|
||||
// Example 2: Counting tokens in chat messages
|
||||
println!("2. Counting Tokens in Chat Messages");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::system("You are a helpful assistant that provides concise answers."),
|
||||
ChatMessage::user("What is the capital of France?"),
|
||||
ChatMessage::assistant("The capital of France is Paris."),
|
||||
ChatMessage::user("What is its population?"),
|
||||
];
|
||||
|
||||
let total_tokens = simple_counter.count_messages(&messages);
|
||||
println!("Number of messages: {}", messages.len());
|
||||
println!("Total tokens (with overhead): {}\n", total_tokens);
|
||||
|
||||
// Example 3: Using ClaudeTokenCounter for Claude models
|
||||
println!("3. Claude-Specific Token Counting");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
let claude_counter = ClaudeTokenCounter::new();
|
||||
let claude_total = claude_counter.count_messages(&messages);
|
||||
|
||||
println!("Claude counter max context: {}", claude_counter.max_context());
|
||||
println!("Claude estimated tokens: {}\n", claude_total);
|
||||
|
||||
// Example 4: Context window management
|
||||
println!("4. Context Window Management");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
let mut context = ContextWindow::new(8192);
|
||||
println!("Created context window with max: {} tokens", context.max());
|
||||
|
||||
// Simulate adding messages
|
||||
let conversation = vec![
|
||||
ChatMessage::user("Tell me about Rust programming."),
|
||||
ChatMessage::assistant(
|
||||
"Rust is a systems programming language focused on safety, \
|
||||
speed, and concurrency. It prevents common bugs like null pointer \
|
||||
dereferences and data races through its ownership system.",
|
||||
),
|
||||
ChatMessage::user("What are its main features?"),
|
||||
ChatMessage::assistant(
|
||||
"Rust's main features include: 1) Memory safety without garbage collection, \
|
||||
2) Zero-cost abstractions, 3) Fearless concurrency, 4) Pattern matching, \
|
||||
5) Type inference, and 6) A powerful macro system.",
|
||||
),
|
||||
];
|
||||
|
||||
for (i, msg) in conversation.iter().enumerate() {
|
||||
let tokens = simple_counter.count_messages(&[msg.clone()]);
|
||||
context.add_tokens(tokens);
|
||||
|
||||
let role = msg.role.as_str();
|
||||
let preview = msg
|
||||
.content
|
||||
.as_ref()
|
||||
.map(|c| {
|
||||
if c.len() > 50 {
|
||||
format!("{}...", &c[..50])
|
||||
} else {
|
||||
c.clone()
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
println!(
|
||||
"Message {}: [{}] \"{}\"",
|
||||
i + 1,
|
||||
role,
|
||||
preview
|
||||
);
|
||||
println!(" Added {} tokens", tokens);
|
||||
println!(" Total used: {} / {}", context.used(), context.max());
|
||||
println!(" Usage: {:.1}%", context.usage_percent() * 100.0);
|
||||
println!(" Progress: {}\n", context.progress_bar(30));
|
||||
}
|
||||
|
||||
// Example 5: Checking context limits
|
||||
println!("5. Checking Context Limits");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
if context.is_near_limit(0.8) {
|
||||
println!("Warning: Context is over 80% full!");
|
||||
} else {
|
||||
println!("Context usage is below 80%");
|
||||
}
|
||||
|
||||
let remaining = context.remaining();
|
||||
println!("Remaining tokens: {}", remaining);
|
||||
|
||||
let new_message_tokens = 500;
|
||||
if context.has_room_for(new_message_tokens) {
|
||||
println!(
|
||||
"Can fit a message of {} tokens",
|
||||
new_message_tokens
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
"Cannot fit a message of {} tokens - would need to compact or start new context",
|
||||
new_message_tokens
|
||||
);
|
||||
}
|
||||
|
||||
// Example 6: Different counter variants
|
||||
println!("\n6. Using Different Counter Variants");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
let counter_8k = SimpleTokenCounter::default_8k();
|
||||
let counter_32k = SimpleTokenCounter::with_32k();
|
||||
let counter_128k = SimpleTokenCounter::with_128k();
|
||||
|
||||
println!("8k context counter: {} tokens", counter_8k.max_context());
|
||||
println!("32k context counter: {} tokens", counter_32k.max_context());
|
||||
println!("128k context counter: {} tokens", counter_128k.max_context());
|
||||
|
||||
let haiku = ClaudeTokenCounter::haiku();
|
||||
let sonnet = ClaudeTokenCounter::sonnet();
|
||||
let opus = ClaudeTokenCounter::opus();
|
||||
|
||||
println!("\nClaude Haiku: {} tokens", haiku.max_context());
|
||||
println!("Claude Sonnet: {} tokens", sonnet.max_context());
|
||||
println!("Claude Opus: {} tokens", opus.max_context());
|
||||
|
||||
// Example 7: Managing context for a long conversation
|
||||
println!("\n7. Long Conversation Simulation");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
let mut long_context = ContextWindow::new(4096); // Smaller context for demo
|
||||
let counter = SimpleTokenCounter::new(4096);
|
||||
|
||||
let mut message_count = 0;
|
||||
let mut compaction_count = 0;
|
||||
|
||||
// Simulate 20 exchanges
|
||||
for i in 0..20 {
|
||||
let user_msg = ChatMessage::user(format!(
|
||||
"This is user message number {} asking a question.",
|
||||
i + 1
|
||||
));
|
||||
let assistant_msg = ChatMessage::assistant(format!(
|
||||
"This is assistant response number {} providing a detailed answer with multiple sentences to make it longer.",
|
||||
i + 1
|
||||
));
|
||||
|
||||
let tokens_needed = counter.count_messages(&[user_msg, assistant_msg]);
|
||||
|
||||
if !long_context.has_room_for(tokens_needed) {
|
||||
println!(
|
||||
"After {} messages, context is full ({}%). Compacting...",
|
||||
message_count,
|
||||
(long_context.usage_percent() * 100.0) as u32
|
||||
);
|
||||
// In a real scenario, we would compact the conversation
|
||||
// For now, just reset; a fuller compaction helper is sketched after this example
|
||||
long_context.reset();
|
||||
compaction_count += 1;
|
||||
}
|
||||
|
||||
long_context.add_tokens(tokens_needed);
|
||||
message_count += 2;
|
||||
}
|
||||
|
||||
println!("Total messages: {}", message_count);
|
||||
println!("Compactions needed: {}", compaction_count);
|
||||
println!("Final context usage: {:.1}%", long_context.usage_percent() * 100.0);
|
||||
println!("Final progress: {}", long_context.progress_bar(40));
|
||||
|
||||
println!("\n=== Example Complete ===");
|
||||
}
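The loop above simply resets the window when it fills; a real agent would summarize or drop older turns instead. Below is a minimal sketch of what such a compaction helper could look like, using the `llm_core` types shown in this change; the `summarize` closure is a stand-in for whatever strategy the caller supplies (another LLM call, heuristic truncation, and so on) and is not part of the crate.

```rust
use llm_core::{ChatMessage, Role};

/// Keep the most recent turns and replace everything older with a summary message.
/// `summarize` is a placeholder supplied by the caller.
fn compact(
    messages: &[ChatMessage],
    keep_recent: usize,
    summarize: impl Fn(&[ChatMessage]) -> String,
) -> Vec<ChatMessage> {
    if messages.len() <= keep_recent + 1 {
        return messages.to_vec();
    }
    let (older, recent) = messages.split_at(messages.len() - keep_recent);
    let mut compacted = Vec::with_capacity(recent.len() + 2);

    // Preserve a leading system prompt if the conversation has one.
    if let Some(first) = older.first() {
        if matches!(first.role, Role::System) {
            compacted.push(first.clone());
        }
    }

    compacted.push(ChatMessage::system(format!(
        "Summary of earlier conversation: {}",
        summarize(older)
    )));
    compacted.extend_from_slice(recent);
    compacted
}
```

After compacting, the caller would recount the kept messages with a `TokenCounter` and update the window via `ContextWindow::set_used`.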
|
||||
796
crates/llm/core/src/lib.rs
Normal file
@@ -0,0 +1,796 @@
|
||||
//! LLM Provider Abstraction Layer
|
||||
//!
|
||||
//! This crate defines the common types and traits for LLM provider integration.
|
||||
//! Providers (Ollama, Anthropic Claude, OpenAI) implement the `LlmProvider` trait
|
||||
//! to enable swapping providers at runtime.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::Stream;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::pin::Pin;
|
||||
use thiserror::Error;
|
||||
|
||||
// ============================================================================
|
||||
// Public Modules
|
||||
// ============================================================================
|
||||
|
||||
pub mod retry;
|
||||
pub mod tokens;
|
||||
|
||||
// Re-export token counting types for convenience
|
||||
pub use tokens::{ClaudeTokenCounter, ContextWindow, SimpleTokenCounter, TokenCounter};
|
||||
|
||||
// Re-export retry types for convenience
|
||||
pub use retry::{is_retryable_error, RetryConfig, RetryStrategy};
|
||||
|
||||
// ============================================================================
|
||||
// Error Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum LlmError {
|
||||
#[error("HTTP error: {0}")]
|
||||
Http(String),
|
||||
|
||||
#[error("JSON parsing error: {0}")]
|
||||
Json(String),
|
||||
|
||||
#[error("Authentication error: {0}")]
|
||||
Auth(String),
|
||||
|
||||
#[error("Rate limit exceeded: retry after {retry_after_secs:?} seconds")]
|
||||
RateLimit { retry_after_secs: Option<u64> },
|
||||
|
||||
#[error("API error: {message}")]
|
||||
Api { message: String, code: Option<String> },
|
||||
|
||||
#[error("Provider error: {0}")]
|
||||
Provider(String),
|
||||
|
||||
#[error("Stream error: {0}")]
|
||||
Stream(String),
|
||||
|
||||
#[error("Request timeout: {0}")]
|
||||
Timeout(String),
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Message Types
|
||||
// ============================================================================
|
||||
|
||||
/// Role of a message in the conversation
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
System,
|
||||
User,
|
||||
Assistant,
|
||||
Tool,
|
||||
}
|
||||
|
||||
impl Role {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Role::System => "system",
|
||||
Role::User => "user",
|
||||
Role::Assistant => "assistant",
|
||||
Role::Tool => "tool",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for Role {
|
||||
fn from(s: &str) -> Self {
|
||||
match s.to_lowercase().as_str() {
|
||||
"system" => Role::System,
|
||||
"user" => Role::User,
|
||||
"assistant" => Role::Assistant,
|
||||
"tool" => Role::Tool,
|
||||
_ => Role::User, // Default fallback
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A message in the conversation
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: Role,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
|
||||
/// Tool calls made by the assistant
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
|
||||
/// For tool role messages: the ID of the tool call this responds to
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
|
||||
/// For tool role messages: the name of the tool
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
/// Create a system message
|
||||
pub fn system(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: Role::System,
|
||||
content: Some(content.into()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a user message
|
||||
pub fn user(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: Role::User,
|
||||
content: Some(content.into()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an assistant message
|
||||
pub fn assistant(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: Role::Assistant,
|
||||
content: Some(content.into()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an assistant message with tool calls (no text content)
|
||||
pub fn assistant_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
|
||||
Self {
|
||||
role: Role::Assistant,
|
||||
content: None,
|
||||
tool_calls: Some(tool_calls),
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a tool result message
|
||||
pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: Role::Tool,
|
||||
content: Some(content.into()),
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(tool_call_id.into()),
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
}
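Taken together, these constructors cover a full tool-use round trip. A small sketch of the message shapes involved; the tool name, call id, and file contents are made up for illustration:

```rust
use llm_core::{ChatMessage, FunctionCall, ToolCall};
use serde_json::json;

fn main() {
    let history = vec![
        ChatMessage::system("You are a helpful assistant."),
        ChatMessage::user("What is in /etc/hosts?"),
        // The model asks for a tool call...
        ChatMessage::assistant_tool_calls(vec![ToolCall {
            id: "call_1".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "read_file".to_string(),
                arguments: json!({ "path": "/etc/hosts" }),
            },
        }]),
        // ...and the runtime answers it with the matching id.
        ChatMessage::tool_result("call_1", "127.0.0.1 localhost"),
        ChatMessage::assistant("The file maps 127.0.0.1 to localhost."),
    ];
    assert_eq!(history.len(), 5);
}
```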
|
||||
|
||||
// ============================================================================
|
||||
// Tool Types
|
||||
// ============================================================================
|
||||
|
||||
/// A tool call requested by the LLM
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
/// Unique identifier for this tool call
|
||||
pub id: String,
|
||||
|
||||
/// The type of tool call (always "function" for now)
|
||||
#[serde(rename = "type", default = "default_function_type")]
|
||||
pub call_type: String,
|
||||
|
||||
/// The function being called
|
||||
pub function: FunctionCall,
|
||||
}
|
||||
|
||||
fn default_function_type() -> String {
|
||||
"function".to_string()
|
||||
}
|
||||
|
||||
/// Details of a function call
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct FunctionCall {
|
||||
/// Name of the function to call
|
||||
pub name: String,
|
||||
|
||||
/// Arguments as a JSON object
|
||||
pub arguments: Value,
|
||||
}
|
||||
|
||||
/// Definition of a tool available to the LLM
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Tool {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: String,
|
||||
|
||||
pub function: ToolFunction,
|
||||
}
|
||||
|
||||
impl Tool {
|
||||
/// Create a new function tool
|
||||
pub fn function(
|
||||
name: impl Into<String>,
|
||||
description: impl Into<String>,
|
||||
parameters: ToolParameters,
|
||||
) -> Self {
|
||||
Self {
|
||||
tool_type: "function".to_string(),
|
||||
function: ToolFunction {
|
||||
name: name.into(),
|
||||
description: description.into(),
|
||||
parameters,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Function definition within a tool
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolFunction {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: ToolParameters,
|
||||
}
|
||||
|
||||
/// Parameters schema for a function
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolParameters {
|
||||
#[serde(rename = "type")]
|
||||
pub param_type: String,
|
||||
|
||||
/// JSON Schema properties object
|
||||
pub properties: Value,
|
||||
|
||||
/// Required parameter names
|
||||
pub required: Vec<String>,
|
||||
}
|
||||
|
||||
impl ToolParameters {
|
||||
/// Create an object parameter schema
|
||||
pub fn object(properties: Value, required: Vec<String>) -> Self {
|
||||
Self {
|
||||
param_type: "object".to_string(),
|
||||
properties,
|
||||
required,
|
||||
}
|
||||
}
|
||||
}
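A short sketch of declaring a tool with these builders. The `read_file` tool and its JSON Schema are illustrative, not part of the crate:

```rust
use llm_core::{Tool, ToolParameters};
use serde_json::json;

fn read_file_tool() -> Tool {
    Tool::function(
        "read_file",
        "Read a UTF-8 text file and return its contents.",
        ToolParameters::object(
            json!({
                "path": {
                    "type": "string",
                    "description": "Absolute path to the file"
                }
            }),
            vec!["path".to_string()],
        ),
    )
}

fn main() {
    // Serializes to the {"type":"function","function":{...}} shape providers expect.
    println!("{}", serde_json::to_string_pretty(&read_file_tool()).unwrap());
}
```

Because `Tool` derives `Serialize`, the result can be embedded directly in a provider request body.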
|
||||
|
||||
// ============================================================================
|
||||
// Streaming Response Types
|
||||
// ============================================================================
|
||||
|
||||
/// A chunk of a streaming response
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamChunk {
|
||||
/// Incremental text content
|
||||
pub content: Option<String>,
|
||||
|
||||
/// Tool calls (may be partial/streaming)
|
||||
pub tool_calls: Option<Vec<ToolCallDelta>>,
|
||||
|
||||
/// Whether this is the final chunk
|
||||
pub done: bool,
|
||||
|
||||
/// Usage statistics (typically only in final chunk)
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
/// Partial tool call for streaming
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolCallDelta {
|
||||
/// Index of this tool call in the array
|
||||
pub index: usize,
|
||||
|
||||
/// Tool call ID (may only be present in first delta)
|
||||
pub id: Option<String>,
|
||||
|
||||
/// Function name (may only be present in first delta)
|
||||
pub function_name: Option<String>,
|
||||
|
||||
/// Incremental arguments string
|
||||
pub arguments_delta: Option<String>,
|
||||
}
|
||||
|
||||
/// Token usage statistics
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Provider Configuration
|
||||
// ============================================================================
|
||||
|
||||
/// Options for a chat request
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ChatOptions {
|
||||
/// Model to use
|
||||
pub model: String,
|
||||
|
||||
/// Temperature (0.0 - 2.0)
|
||||
pub temperature: Option<f32>,
|
||||
|
||||
/// Maximum tokens to generate
|
||||
pub max_tokens: Option<u32>,
|
||||
|
||||
/// Top-p sampling
|
||||
pub top_p: Option<f32>,
|
||||
|
||||
/// Stop sequences
|
||||
pub stop: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl ChatOptions {
|
||||
pub fn new(model: impl Into<String>) -> Self {
|
||||
Self {
|
||||
model: model.into(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_temperature(mut self, temp: f32) -> Self {
|
||||
self.temperature = Some(temp);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_tokens(mut self, max: u32) -> Self {
|
||||
self.max_tokens = Some(max);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Provider Trait
|
||||
// ============================================================================
|
||||
|
||||
/// A boxed stream of chunks
|
||||
pub type ChunkStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, LlmError>> + Send>>;
|
||||
|
||||
/// The main trait that all LLM providers must implement
|
||||
#[async_trait]
|
||||
pub trait LlmProvider: Send + Sync {
|
||||
/// Get the provider name (e.g., "ollama", "anthropic", "openai")
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Get the current model name
|
||||
fn model(&self) -> &str;
|
||||
|
||||
/// Send a chat request and receive a streaming response
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `messages` - The conversation history
|
||||
/// * `options` - Request options (model, temperature, etc.)
|
||||
/// * `tools` - Optional list of tools the model can use
|
||||
///
|
||||
/// # Returns
|
||||
/// A stream of response chunks
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
options: &ChatOptions,
|
||||
tools: Option<&[Tool]>,
|
||||
) -> Result<ChunkStream, LlmError>;
|
||||
|
||||
/// Send a chat request and receive a complete response (non-streaming)
|
||||
///
|
||||
/// Default implementation collects the stream, but providers may override
|
||||
/// for efficiency.
|
||||
async fn chat(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
options: &ChatOptions,
|
||||
tools: Option<&[Tool]>,
|
||||
) -> Result<ChatResponse, LlmError> {
|
||||
use futures::StreamExt;
|
||||
|
||||
let mut stream = self.chat_stream(messages, options, tools).await?;
|
||||
let mut content = String::new();
|
||||
let mut tool_calls: Vec<PartialToolCall> = Vec::new();
|
||||
let mut usage = None;
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk?;
|
||||
|
||||
if let Some(text) = chunk.content {
|
||||
content.push_str(&text);
|
||||
}
|
||||
|
||||
if let Some(deltas) = chunk.tool_calls {
|
||||
for delta in deltas {
|
||||
// Grow the tool_calls vec if needed
|
||||
while tool_calls.len() <= delta.index {
|
||||
tool_calls.push(PartialToolCall::default());
|
||||
}
|
||||
|
||||
let partial = &mut tool_calls[delta.index];
|
||||
if let Some(id) = delta.id {
|
||||
partial.id = Some(id);
|
||||
}
|
||||
if let Some(name) = delta.function_name {
|
||||
partial.function_name = Some(name);
|
||||
}
|
||||
if let Some(args) = delta.arguments_delta {
|
||||
partial.arguments.push_str(&args);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if chunk.usage.is_some() {
|
||||
usage = chunk.usage;
|
||||
}
|
||||
}
|
||||
|
||||
// Convert partial tool calls to complete tool calls
|
||||
let final_tool_calls: Vec<ToolCall> = tool_calls
|
||||
.into_iter()
|
||||
.filter_map(|p| p.try_into_tool_call())
|
||||
.collect();
|
||||
|
||||
Ok(ChatResponse {
|
||||
content: if content.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(content)
|
||||
},
|
||||
tool_calls: if final_tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(final_tool_calls)
|
||||
},
|
||||
usage,
|
||||
})
|
||||
}
|
||||
}
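Because `chat` has a default implementation in terms of `chat_stream`, a new provider only needs to produce a `ChunkStream`. The mock below is a sketch for tests, not one of the real providers; it assumes `tokio` and `futures` as used elsewhere in the workspace.

```rust
use async_trait::async_trait;
use futures::stream;
use llm_core::{
    ChatMessage, ChatOptions, ChunkStream, LlmError, LlmProvider, StreamChunk, Tool,
};

struct EchoProvider;

#[async_trait]
impl LlmProvider for EchoProvider {
    fn name(&self) -> &str {
        "echo"
    }

    fn model(&self) -> &str {
        "echo-1"
    }

    async fn chat_stream(
        &self,
        messages: &[ChatMessage],
        _options: &ChatOptions,
        _tools: Option<&[Tool]>,
    ) -> Result<ChunkStream, LlmError> {
        // Echo the last message's content back, followed by a final "done" chunk.
        let text = messages
            .last()
            .and_then(|m| m.content.clone())
            .unwrap_or_default();
        let chunks = vec![
            Ok(StreamChunk {
                content: Some(text),
                tool_calls: None,
                done: false,
                usage: None,
            }),
            Ok(StreamChunk {
                content: None,
                tool_calls: None,
                done: true,
                usage: None,
            }),
        ];
        Ok(Box::pin(stream::iter(chunks)))
    }
}

#[tokio::main]
async fn main() -> Result<(), LlmError> {
    let provider = EchoProvider;
    // The default `chat` collects the stream into a single ChatResponse.
    let response = provider
        .chat(&[ChatMessage::user("hello")], &ChatOptions::new("echo-1"), None)
        .await?;
    assert_eq!(response.content.as_deref(), Some("hello"));
    Ok(())
}
```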
|
||||
|
||||
/// A complete chat response (non-streaming)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ChatResponse {
|
||||
pub content: Option<String>,
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
/// Helper for accumulating streaming tool calls
|
||||
#[derive(Default)]
|
||||
struct PartialToolCall {
|
||||
id: Option<String>,
|
||||
function_name: Option<String>,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
impl PartialToolCall {
|
||||
fn try_into_tool_call(self) -> Option<ToolCall> {
|
||||
let id = self.id?;
|
||||
let name = self.function_name?;
|
||||
let arguments: Value = serde_json::from_str(&self.arguments).ok()?;
|
||||
|
||||
Some(ToolCall {
|
||||
id,
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall { name, arguments },
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Authentication
|
||||
// ============================================================================
|
||||
|
||||
/// Authentication method for LLM providers
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AuthMethod {
|
||||
/// No authentication (for local providers like Ollama)
|
||||
None,
|
||||
|
||||
/// API key authentication
|
||||
ApiKey(String),
|
||||
|
||||
/// OAuth access token (from login flow)
|
||||
OAuth {
|
||||
access_token: String,
|
||||
refresh_token: Option<String>,
|
||||
expires_at: Option<u64>,
|
||||
},
|
||||
}
|
||||
|
||||
impl AuthMethod {
|
||||
/// Create API key auth
|
||||
pub fn api_key(key: impl Into<String>) -> Self {
|
||||
Self::ApiKey(key.into())
|
||||
}
|
||||
|
||||
/// Create OAuth auth from tokens
|
||||
pub fn oauth(access_token: impl Into<String>) -> Self {
|
||||
Self::OAuth {
|
||||
access_token: access_token.into(),
|
||||
refresh_token: None,
|
||||
expires_at: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create OAuth auth with refresh token
|
||||
pub fn oauth_with_refresh(
|
||||
access_token: impl Into<String>,
|
||||
refresh_token: impl Into<String>,
|
||||
expires_at: Option<u64>,
|
||||
) -> Self {
|
||||
Self::OAuth {
|
||||
access_token: access_token.into(),
|
||||
refresh_token: Some(refresh_token.into()),
|
||||
expires_at,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the bearer token for Authorization header
|
||||
pub fn bearer_token(&self) -> Option<&str> {
|
||||
match self {
|
||||
Self::None => None,
|
||||
Self::ApiKey(key) => Some(key),
|
||||
Self::OAuth { access_token, .. } => Some(access_token),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if token might need refresh
|
||||
pub fn needs_refresh(&self) -> bool {
|
||||
match self {
|
||||
Self::OAuth {
|
||||
expires_at: Some(exp),
|
||||
refresh_token: Some(_),
|
||||
..
|
||||
} => {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0);
|
||||
// Refresh if expiring within 5 minutes
|
||||
*exp < now + 300
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
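A brief illustration of how these helpers behave; the key strings and the epoch-second expiry are arbitrary:

```rust
use llm_core::AuthMethod;

fn main() {
    let api = AuthMethod::api_key("sk-example");
    assert_eq!(api.bearer_token(), Some("sk-example"));
    assert!(!api.needs_refresh());

    // A token whose expiry is already in the past needs a refresh,
    // but only if a refresh token is actually available.
    let expired = AuthMethod::oauth_with_refresh("access", "refresh", Some(1));
    assert!(expired.needs_refresh());

    let bare = AuthMethod::oauth("access");
    assert!(!bare.needs_refresh());
}
```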
|
||||
|
||||
/// Device code response for OAuth device flow
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DeviceCodeResponse {
|
||||
/// Code the user enters on the verification page
|
||||
pub user_code: String,
|
||||
|
||||
/// URL the user visits to authorize
|
||||
pub verification_uri: String,
|
||||
|
||||
/// Full URL with code pre-filled (if supported)
|
||||
pub verification_uri_complete: Option<String>,
|
||||
|
||||
/// Device code for polling (internal use)
|
||||
pub device_code: String,
|
||||
|
||||
/// How often to poll (in seconds)
|
||||
pub interval: u64,
|
||||
|
||||
/// When the codes expire (in seconds)
|
||||
pub expires_in: u64,
|
||||
}
|
||||
|
||||
/// Result of polling for device authorization
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum DeviceAuthResult {
|
||||
/// Still waiting for user to authorize
|
||||
Pending,
|
||||
|
||||
/// User authorized, here are the tokens
|
||||
Success {
|
||||
access_token: String,
|
||||
refresh_token: Option<String>,
|
||||
expires_in: Option<u64>,
|
||||
},
|
||||
|
||||
/// User denied authorization
|
||||
Denied,
|
||||
|
||||
/// Code expired
|
||||
Expired,
|
||||
}
|
||||
|
||||
/// Trait for providers that support OAuth device flow
|
||||
#[async_trait]
|
||||
pub trait OAuthProvider {
|
||||
/// Start the device authorization flow
|
||||
async fn start_device_auth(&self) -> Result<DeviceCodeResponse, LlmError>;
|
||||
|
||||
/// Poll for the authorization result
|
||||
async fn poll_device_auth(&self, device_code: &str) -> Result<DeviceAuthResult, LlmError>;
|
||||
|
||||
/// Refresh an access token using a refresh token
|
||||
async fn refresh_token(&self, refresh_token: &str) -> Result<AuthMethod, LlmError>;
|
||||
}
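The types above describe the standard OAuth device flow. The loop below sketches how a CLI might drive it against any `OAuthProvider`; error handling is minimal, the sleep uses `tokio`, and the expiry calculation mirrors the `needs_refresh` logic above.

```rust
use llm_core::{AuthMethod, DeviceAuthResult, LlmError, OAuthProvider};
use std::time::Duration;

async fn login(provider: &impl OAuthProvider) -> Result<AuthMethod, LlmError> {
    let dc = provider.start_device_auth().await?;
    println!("Visit {} and enter code {}", dc.verification_uri, dc.user_code);

    loop {
        // Poll at the interval the provider asked for.
        tokio::time::sleep(Duration::from_secs(dc.interval)).await;
        match provider.poll_device_auth(&dc.device_code).await? {
            DeviceAuthResult::Pending => continue,
            DeviceAuthResult::Success {
                access_token,
                refresh_token,
                expires_in,
            } => {
                // Convert a relative expiry into an absolute unix timestamp.
                let expires_at = expires_in.map(|secs| {
                    std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .map(|d| d.as_secs())
                        .unwrap_or(0)
                        + secs
                });
                return Ok(match refresh_token {
                    Some(rt) => AuthMethod::oauth_with_refresh(access_token, rt, expires_at),
                    None => AuthMethod::oauth(access_token),
                });
            }
            DeviceAuthResult::Denied => {
                return Err(LlmError::Auth("authorization denied".to_string()))
            }
            DeviceAuthResult::Expired => {
                return Err(LlmError::Auth("device code expired".to_string()))
            }
        }
    }
}
```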
|
||||
|
||||
/// Stored credentials for a provider
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StoredCredentials {
|
||||
pub provider: String,
|
||||
pub access_token: String,
|
||||
pub refresh_token: Option<String>,
|
||||
pub expires_at: Option<u64>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Provider Status & Info
|
||||
// ============================================================================
|
||||
|
||||
/// Status information for a provider connection
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProviderStatus {
|
||||
/// Provider name
|
||||
pub provider: String,
|
||||
|
||||
/// Whether the connection is authenticated
|
||||
pub authenticated: bool,
|
||||
|
||||
/// Current user/account info if authenticated
|
||||
pub account: Option<AccountInfo>,
|
||||
|
||||
/// Current model being used
|
||||
pub model: String,
|
||||
|
||||
/// API endpoint URL
|
||||
pub endpoint: String,
|
||||
|
||||
/// Whether the provider is reachable
|
||||
pub reachable: bool,
|
||||
|
||||
/// Any status message or error
|
||||
pub message: Option<String>,
|
||||
}
|
||||
|
||||
/// Account/user information from the provider
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AccountInfo {
|
||||
/// Account/user ID
|
||||
pub id: Option<String>,
|
||||
|
||||
/// Display name or email
|
||||
pub name: Option<String>,
|
||||
|
||||
/// Account email
|
||||
pub email: Option<String>,
|
||||
|
||||
/// Account type (free, pro, team, enterprise)
|
||||
pub account_type: Option<String>,
|
||||
|
||||
/// Organization name if applicable
|
||||
pub organization: Option<String>,
|
||||
}
|
||||
|
||||
/// Usage statistics from the provider
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UsageStats {
|
||||
/// Total tokens used in current period
|
||||
pub tokens_used: Option<u64>,
|
||||
|
||||
/// Token limit for current period (if applicable)
|
||||
pub token_limit: Option<u64>,
|
||||
|
||||
/// Number of requests made
|
||||
pub requests_made: Option<u64>,
|
||||
|
||||
/// Request limit (if applicable)
|
||||
pub request_limit: Option<u64>,
|
||||
|
||||
/// Cost incurred (if available)
|
||||
pub cost_usd: Option<f64>,
|
||||
|
||||
/// Period start timestamp
|
||||
pub period_start: Option<u64>,
|
||||
|
||||
/// Period end timestamp
|
||||
pub period_end: Option<u64>,
|
||||
}
|
||||
|
||||
/// Available model information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelInfo {
|
||||
/// Model ID/name
|
||||
pub id: String,
|
||||
|
||||
/// Human-readable display name
|
||||
pub display_name: Option<String>,
|
||||
|
||||
/// Model description
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Context window size (tokens)
|
||||
pub context_window: Option<u32>,
|
||||
|
||||
/// Max output tokens
|
||||
pub max_output_tokens: Option<u32>,
|
||||
|
||||
/// Whether the model supports tool use
|
||||
pub supports_tools: bool,
|
||||
|
||||
/// Whether the model supports vision/images
|
||||
pub supports_vision: bool,
|
||||
|
||||
/// Input token price per 1M tokens (USD)
|
||||
pub input_price_per_mtok: Option<f64>,
|
||||
|
||||
/// Output token price per 1M tokens (USD)
|
||||
pub output_price_per_mtok: Option<f64>,
|
||||
}
|
||||
|
||||
/// Trait for providers that support status/info queries
|
||||
#[async_trait]
|
||||
pub trait ProviderInfo {
|
||||
/// Get the current connection status
|
||||
async fn status(&self) -> Result<ProviderStatus, LlmError>;
|
||||
|
||||
/// Get account information (if authenticated)
|
||||
async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError>;
|
||||
|
||||
/// Get usage statistics (if available)
|
||||
async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError>;
|
||||
|
||||
/// List available models
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError>;
|
||||
|
||||
/// Check if a specific model is available
|
||||
async fn model_info(&self, model_id: &str) -> Result<Option<ModelInfo>, LlmError> {
|
||||
let models = self.list_models().await?;
|
||||
Ok(models.into_iter().find(|m| m.id == model_id))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Provider Factory
|
||||
// ============================================================================
|
||||
|
||||
/// Supported LLM providers
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ProviderType {
|
||||
Ollama,
|
||||
Anthropic,
|
||||
OpenAI,
|
||||
}
|
||||
|
||||
impl ProviderType {
|
||||
pub fn from_str(s: &str) -> Option<Self> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"ollama" => Some(Self::Ollama),
|
||||
"anthropic" | "claude" => Some(Self::Anthropic),
|
||||
"openai" | "gpt" => Some(Self::OpenAI),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Ollama => "ollama",
|
||||
Self::Anthropic => "anthropic",
|
||||
Self::OpenAI => "openai",
|
||||
}
|
||||
}
|
||||
|
||||
/// Default model for this provider
|
||||
pub fn default_model(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Ollama => "qwen3:8b",
|
||||
Self::Anthropic => "claude-sonnet-4-20250514",
|
||||
Self::OpenAI => "gpt-4o",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ProviderType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.as_str())
|
||||
}
|
||||
}
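Selecting a provider from a CLI argument or config value then looks roughly like this; falling back to Ollama when the name is unrecognized is an arbitrary choice for the sketch:

```rust
use llm_core::{ChatOptions, ProviderType};

fn main() {
    let requested = std::env::args().nth(1).unwrap_or_else(|| "ollama".to_string());
    // "claude" and "gpt" are accepted aliases; unknown names fall back to Ollama here.
    let provider = ProviderType::from_str(&requested).unwrap_or(ProviderType::Ollama);
    let options = ChatOptions::new(provider.default_model())
        .with_temperature(0.2)
        .with_max_tokens(1024);
    println!("Using {} with model {}", provider, options.model);
}
```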
|
||||
386
crates/llm/core/src/retry.rs
Normal file
@@ -0,0 +1,386 @@
|
||||
//! Error recovery and retry logic for LLM operations
|
||||
//!
|
||||
//! This module provides configurable retry strategies with exponential backoff
|
||||
//! for handling transient failures when communicating with LLM providers.
|
||||
|
||||
use crate::LlmError;
|
||||
use rand::Rng;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Configuration for retry behavior
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetryConfig {
|
||||
/// Maximum number of retry attempts
|
||||
pub max_retries: u32,
|
||||
/// Initial delay before first retry (in milliseconds)
|
||||
pub initial_delay_ms: u64,
|
||||
/// Maximum delay between retries (in milliseconds)
|
||||
pub max_delay_ms: u64,
|
||||
/// Multiplier for exponential backoff
|
||||
pub backoff_multiplier: f32,
|
||||
}
|
||||
|
||||
impl Default for RetryConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_retries: 3,
|
||||
initial_delay_ms: 1000,
|
||||
max_delay_ms: 30000,
|
||||
backoff_multiplier: 2.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RetryConfig {
|
||||
/// Create a new retry configuration with custom values
|
||||
pub fn new(
|
||||
max_retries: u32,
|
||||
initial_delay_ms: u64,
|
||||
max_delay_ms: u64,
|
||||
backoff_multiplier: f32,
|
||||
) -> Self {
|
||||
Self {
|
||||
max_retries,
|
||||
initial_delay_ms,
|
||||
max_delay_ms,
|
||||
backoff_multiplier,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a configuration with no retries
|
||||
pub fn no_retry() -> Self {
|
||||
Self {
|
||||
max_retries: 0,
|
||||
initial_delay_ms: 0,
|
||||
max_delay_ms: 0,
|
||||
backoff_multiplier: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a configuration with aggressive retries for rate-limited scenarios
|
||||
pub fn aggressive() -> Self {
|
||||
Self {
|
||||
max_retries: 5,
|
||||
initial_delay_ms: 2000,
|
||||
max_delay_ms: 60000,
|
||||
backoff_multiplier: 2.5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Determines whether an error is retryable
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `error` - The error to check
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if the error is transient and the operation should be retried,
|
||||
/// `false` if the error is permanent and retrying won't help
|
||||
pub fn is_retryable_error(error: &LlmError) -> bool {
|
||||
match error {
|
||||
// Always retry rate limits
|
||||
LlmError::RateLimit { .. } => true,
|
||||
|
||||
// Always retry timeouts
|
||||
LlmError::Timeout(_) => true,
|
||||
|
||||
// Retry HTTP errors that are server-side (5xx)
|
||||
LlmError::Http(msg) => {
|
||||
// Check if the error message contains a 5xx status code
|
||||
msg.contains("500")
|
||||
|| msg.contains("502")
|
||||
|| msg.contains("503")
|
||||
|| msg.contains("504")
|
||||
|| msg.contains("Internal Server Error")
|
||||
|| msg.contains("Bad Gateway")
|
||||
|| msg.contains("Service Unavailable")
|
||||
|| msg.contains("Gateway Timeout")
|
||||
}
|
||||
|
||||
// Don't retry authentication errors - they need user intervention
|
||||
LlmError::Auth(_) => false,
|
||||
|
||||
// Don't retry JSON parsing errors - the data is malformed
|
||||
LlmError::Json(_) => false,
|
||||
|
||||
// Don't retry API errors - these are typically client-side issues
|
||||
LlmError::Api { .. } => false,
|
||||
|
||||
// Provider errors might be transient, but we conservatively don't retry
|
||||
LlmError::Provider(_) => false,
|
||||
|
||||
// Stream errors are typically not retryable
|
||||
LlmError::Stream(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Strategy for retrying failed operations with exponential backoff
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetryStrategy {
|
||||
config: RetryConfig,
|
||||
}
|
||||
|
||||
impl RetryStrategy {
|
||||
/// Create a new retry strategy with the given configuration
|
||||
pub fn new(config: RetryConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
/// Create a retry strategy with default configuration
|
||||
pub fn default_config() -> Self {
|
||||
Self::new(RetryConfig::default())
|
||||
}
|
||||
|
||||
/// Execute an async operation with retries
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `operation` - A function that returns a Future producing a Result
|
||||
///
|
||||
/// # Returns
|
||||
/// The result of the operation, or the last error if all retries fail
|
||||
///
|
||||
/// # Example
|
||||
/// ```ignore
|
||||
/// let strategy = RetryStrategy::default_config();
|
||||
/// let result = strategy.execute(|| async {
|
||||
/// // Your LLM API call here
|
||||
/// llm_client.chat(&messages, &options, None).await
|
||||
/// }).await?;
|
||||
/// ```
|
||||
pub async fn execute<F, T, Fut>(&self, operation: F) -> Result<T, LlmError>
|
||||
where
|
||||
F: Fn() -> Fut,
|
||||
Fut: std::future::Future<Output = Result<T, LlmError>>,
|
||||
{
|
||||
let mut attempt = 0;
|
||||
|
||||
loop {
|
||||
// Try the operation
|
||||
match operation().await {
|
||||
Ok(result) => return Ok(result),
|
||||
Err(err) => {
|
||||
// Check if we should retry
|
||||
if !is_retryable_error(&err) {
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
attempt += 1;
|
||||
|
||||
// Check if we've exhausted retries
|
||||
if attempt > self.config.max_retries {
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
// Calculate delay with exponential backoff and jitter
|
||||
let delay = self.delay_for_attempt(attempt);
|
||||
|
||||
// Log retry attempt (in a real implementation, you might use tracing)
|
||||
eprintln!(
|
||||
"Retry attempt {}/{} after {:?}",
|
||||
attempt, self.config.max_retries, delay
|
||||
);
|
||||
|
||||
// Sleep before next attempt
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate the delay for a given attempt number with jitter
|
||||
///
|
||||
/// Uses exponential backoff: delay = initial_delay * (backoff_multiplier ^ (attempt - 1))
|
||||
/// Adds random jitter of ±10% to prevent thundering herd problems
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `attempt` - The attempt number (1-indexed)
|
||||
///
|
||||
/// # Returns
|
||||
/// The delay duration to wait before the next retry
|
||||
fn delay_for_attempt(&self, attempt: u32) -> Duration {
|
||||
// Calculate base delay with exponential backoff
|
||||
let base_delay_ms = self.config.initial_delay_ms as f64
|
||||
* self.config.backoff_multiplier.powi((attempt - 1) as i32) as f64;
|
||||
|
||||
// Cap at max_delay_ms
|
||||
let capped_delay_ms = base_delay_ms.min(self.config.max_delay_ms as f64);
|
||||
|
||||
// Add jitter: ±10%
|
||||
let mut rng = rand::thread_rng();
|
||||
let jitter_factor = rng.gen_range(0.9..=1.1);
|
||||
let final_delay_ms = capped_delay_ms * jitter_factor;
|
||||
|
||||
Duration::from_millis(final_delay_ms as u64)
|
||||
}
|
||||
}
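With the default configuration the base delays work out to roughly 1 s, 2 s, and 4 s for attempts one to three (each with ±10% jitter), capped at 30 s. A runnable sketch of `execute` wrapping a flaky operation, using an atomic counter in place of a real API call and assuming `tokio` for the runtime:

```rust
use llm_core::{LlmError, RetryConfig, RetryStrategy};
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;

#[tokio::main]
async fn main() -> Result<(), LlmError> {
    // Short delays so the example finishes quickly.
    let strategy = RetryStrategy::new(RetryConfig::new(3, 50, 500, 2.0));
    let attempts = Arc::new(AtomicU32::new(0));
    let attempts_in_op = attempts.clone();

    let answer = strategy
        .execute(|| {
            let attempts = attempts_in_op.clone();
            async move {
                // Fail twice with a retryable error, then succeed.
                if attempts.fetch_add(1, Ordering::SeqCst) < 2 {
                    Err(LlmError::Timeout("simulated timeout".to_string()))
                } else {
                    Ok(42)
                }
            }
        })
        .await?;

    assert_eq!(answer, 42);
    assert_eq!(attempts.load(Ordering::SeqCst), 3);
    Ok(())
}
```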
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[test]
|
||||
fn test_default_retry_config() {
|
||||
let config = RetryConfig::default();
|
||||
assert_eq!(config.max_retries, 3);
|
||||
assert_eq!(config.initial_delay_ms, 1000);
|
||||
assert_eq!(config.max_delay_ms, 30000);
|
||||
assert_eq!(config.backoff_multiplier, 2.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_retry_config() {
|
||||
let config = RetryConfig::no_retry();
|
||||
assert_eq!(config.max_retries, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_retryable_error() {
|
||||
// Retryable errors
|
||||
assert!(is_retryable_error(&LlmError::RateLimit {
|
||||
retry_after_secs: Some(60)
|
||||
}));
|
||||
assert!(is_retryable_error(&LlmError::Timeout(
|
||||
"Request timed out".to_string()
|
||||
)));
|
||||
assert!(is_retryable_error(&LlmError::Http(
|
||||
"500 Internal Server Error".to_string()
|
||||
)));
|
||||
assert!(is_retryable_error(&LlmError::Http(
|
||||
"503 Service Unavailable".to_string()
|
||||
)));
|
||||
|
||||
// Non-retryable errors
|
||||
assert!(!is_retryable_error(&LlmError::Auth(
|
||||
"Invalid API key".to_string()
|
||||
)));
|
||||
assert!(!is_retryable_error(&LlmError::Json(
|
||||
"Invalid JSON".to_string()
|
||||
)));
|
||||
assert!(!is_retryable_error(&LlmError::Api {
|
||||
message: "Invalid request".to_string(),
|
||||
code: Some("400".to_string())
|
||||
}));
|
||||
assert!(!is_retryable_error(&LlmError::Http(
|
||||
"400 Bad Request".to_string()
|
||||
)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delay_calculation() {
|
||||
let config = RetryConfig::default();
|
||||
let strategy = RetryStrategy::new(config);
|
||||
|
||||
// Test that delays increase exponentially
|
||||
let delay1 = strategy.delay_for_attempt(1);
|
||||
let delay2 = strategy.delay_for_attempt(2);
|
||||
let delay3 = strategy.delay_for_attempt(3);
|
||||
|
||||
// Base delays should be around 1000ms, 2000ms, 4000ms (with jitter)
|
||||
assert!(delay1.as_millis() >= 900 && delay1.as_millis() <= 1100);
|
||||
assert!(delay2.as_millis() >= 1800 && delay2.as_millis() <= 2200);
|
||||
assert!(delay3.as_millis() >= 3600 && delay3.as_millis() <= 4400);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delay_max_cap() {
|
||||
let config = RetryConfig {
|
||||
max_retries: 10,
|
||||
initial_delay_ms: 1000,
|
||||
max_delay_ms: 5000,
|
||||
backoff_multiplier: 2.0,
|
||||
};
|
||||
let strategy = RetryStrategy::new(config);
|
||||
|
||||
// Even with high attempt numbers, delay should be capped
|
||||
let delay = strategy.delay_for_attempt(10);
|
||||
assert!(delay.as_millis() <= 5500); // max + jitter
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retry_success_on_first_attempt() {
|
||||
let strategy = RetryStrategy::default_config();
|
||||
let call_count = Arc::new(AtomicU32::new(0));
|
||||
let count_clone = call_count.clone();
|
||||
|
||||
let result = strategy
|
||||
.execute(|| {
|
||||
let count = count_clone.clone();
|
||||
async move {
|
||||
count.fetch_add(1, Ordering::SeqCst);
|
||||
Ok::<_, LlmError>(42)
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
assert_eq!(result.unwrap(), 42);
|
||||
assert_eq!(call_count.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retry_success_after_retries() {
|
||||
let config = RetryConfig::new(3, 10, 100, 2.0); // Fast retries for testing
|
||||
let strategy = RetryStrategy::new(config);
|
||||
let call_count = Arc::new(AtomicU32::new(0));
|
||||
let count_clone = call_count.clone();
|
||||
|
||||
let result = strategy
|
||||
.execute(|| {
|
||||
let count = count_clone.clone();
|
||||
async move {
|
||||
let current = count.fetch_add(1, Ordering::SeqCst) + 1;
|
||||
if current < 3 {
|
||||
Err(LlmError::Timeout("Timeout".to_string()))
|
||||
} else {
|
||||
Ok(42)
|
||||
}
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
assert_eq!(result.unwrap(), 42);
|
||||
assert_eq!(call_count.load(Ordering::SeqCst), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retry_exhausted() {
|
||||
let config = RetryConfig::new(2, 10, 100, 2.0); // Fast retries for testing
|
||||
let strategy = RetryStrategy::new(config);
|
||||
let call_count = Arc::new(AtomicU32::new(0));
|
||||
let count_clone = call_count.clone();
|
||||
|
||||
let result = strategy
|
||||
.execute(|| {
|
||||
let count = count_clone.clone();
|
||||
async move {
|
||||
count.fetch_add(1, Ordering::SeqCst);
|
||||
Err::<(), _>(LlmError::Timeout("Always fails".to_string()))
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert_eq!(call_count.load(Ordering::SeqCst), 3); // Initial attempt + 2 retries
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_non_retryable_error() {
|
||||
let strategy = RetryStrategy::default_config();
|
||||
let call_count = Arc::new(AtomicU32::new(0));
|
||||
let count_clone = call_count.clone();
|
||||
|
||||
let result = strategy
|
||||
.execute(|| {
|
||||
let count = count_clone.clone();
|
||||
async move {
|
||||
count.fetch_add(1, Ordering::SeqCst);
|
||||
Err::<(), _>(LlmError::Auth("Invalid API key".to_string()))
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert_eq!(call_count.load(Ordering::SeqCst), 1); // Should not retry
|
||||
}
|
||||
}
|
||||
607
crates/llm/core/src/tokens.rs
Normal file
@@ -0,0 +1,607 @@
|
||||
//! Token counting utilities for LLM context management
|
||||
//!
|
||||
//! This module provides token counting abstractions and implementations for
|
||||
//! managing LLM context windows. Token counters estimate token usage without
|
||||
//! requiring external tokenization libraries, using heuristic-based approaches.
|
||||
|
||||
use crate::ChatMessage;
|
||||
|
||||
// ============================================================================
|
||||
// TokenCounter Trait
|
||||
// ============================================================================
|
||||
|
||||
/// Trait for counting tokens in text and chat messages
|
||||
///
|
||||
/// Implementations provide model-specific token counting logic to help
|
||||
/// manage context windows and estimate API costs.
|
||||
pub trait TokenCounter: Send + Sync {
|
||||
/// Count tokens in a string
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `text` - The text to count tokens for
|
||||
///
|
||||
/// # Returns
|
||||
/// Estimated number of tokens
|
||||
fn count(&self, text: &str) -> usize;
|
||||
|
||||
/// Count tokens in chat messages
|
||||
///
|
||||
/// This accounts for both the message content and the overhead
|
||||
/// from the chat message structure (roles, delimiters, etc.).
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `messages` - The messages to count tokens for
|
||||
///
|
||||
/// # Returns
|
||||
/// Estimated total tokens including message structure overhead
|
||||
fn count_messages(&self, messages: &[ChatMessage]) -> usize;
|
||||
|
||||
/// Get the model's max context window size
|
||||
///
|
||||
/// # Returns
|
||||
/// Maximum number of tokens the model can handle
|
||||
fn max_context(&self) -> usize;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SimpleTokenCounter
|
||||
// ============================================================================
|
||||
|
||||
/// A basic token counter using simple heuristics
|
||||
///
|
||||
/// This counter uses the rule of thumb that English text averages about
|
||||
/// 4 characters per token. It adds overhead for message structure.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use llm_core::tokens::{TokenCounter, SimpleTokenCounter};
|
||||
/// use llm_core::ChatMessage;
|
||||
///
|
||||
/// let counter = SimpleTokenCounter::new(8192);
|
||||
/// let text = "Hello, world!";
|
||||
/// let tokens = counter.count(text);
|
||||
/// assert!(tokens > 0);
|
||||
///
|
||||
/// let messages = vec![
|
||||
/// ChatMessage::user("What is the weather?"),
|
||||
/// ChatMessage::assistant("I don't have access to weather data."),
|
||||
/// ];
|
||||
/// let total = counter.count_messages(&messages);
|
||||
/// assert!(total > 0);
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SimpleTokenCounter {
|
||||
max_context: usize,
|
||||
}
|
||||
|
||||
impl SimpleTokenCounter {
|
||||
/// Create a new simple token counter
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `max_context` - Maximum context window size for the model
|
||||
pub fn new(max_context: usize) -> Self {
|
||||
Self { max_context }
|
||||
}
|
||||
|
||||
/// Create a token counter with a default 8192 token context
|
||||
pub fn default_8k() -> Self {
|
||||
Self::new(8192)
|
||||
}
|
||||
|
||||
/// Create a token counter with a 32k token context
|
||||
pub fn with_32k() -> Self {
|
||||
Self::new(32768)
|
||||
}
|
||||
|
||||
/// Create a token counter with a 128k token context
|
||||
pub fn with_128k() -> Self {
|
||||
Self::new(131072)
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenCounter for SimpleTokenCounter {
|
||||
fn count(&self, text: &str) -> usize {
|
||||
// Estimate: approximately 4 characters per token for English
|
||||
// Add 3 before dividing to round up
|
||||
(text.len() + 3) / 4
|
||||
}
|
||||
|
||||
fn count_messages(&self, messages: &[ChatMessage]) -> usize {
|
||||
let mut total = 0;
|
||||
|
||||
// Base overhead for message formatting (estimated)
|
||||
// Each message has role, delimiters, etc.
|
||||
const MESSAGE_OVERHEAD: usize = 4;
|
||||
|
||||
for msg in messages {
|
||||
// Count role
|
||||
total += MESSAGE_OVERHEAD;
|
||||
|
||||
// Count content
|
||||
if let Some(content) = &msg.content {
|
||||
total += self.count(content);
|
||||
}
|
||||
|
||||
// Count tool calls (more expensive due to JSON structure)
|
||||
if let Some(tool_calls) = &msg.tool_calls {
|
||||
for tc in tool_calls {
|
||||
// ID overhead
|
||||
total += self.count(&tc.id);
|
||||
// Function name
|
||||
total += self.count(&tc.function.name);
|
||||
// Arguments (JSON serialized, add 20% overhead for JSON structure)
|
||||
let args_str = tc.function.arguments.to_string();
|
||||
total += (self.count(&args_str) * 12) / 10;
|
||||
}
|
||||
}
|
||||
|
||||
// Count tool call id for tool result messages
|
||||
if let Some(tool_call_id) = &msg.tool_call_id {
|
||||
total += self.count(tool_call_id);
|
||||
}
|
||||
|
||||
// Count tool name for tool result messages
|
||||
if let Some(name) = &msg.name {
|
||||
total += self.count(name);
|
||||
}
|
||||
}
|
||||
|
||||
total
|
||||
}
|
||||
|
||||
fn max_context(&self) -> usize {
|
||||
self.max_context
|
||||
}
|
||||
}
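As a concrete check of the heuristic: "Hello!" is 6 characters, so (6 + 3) / 4 = 2 tokens for the text alone, and the 4-token per-message overhead brings a single user message to 6:

```rust
use llm_core::{ChatMessage, SimpleTokenCounter, TokenCounter};

fn main() {
    let counter = SimpleTokenCounter::default_8k();

    // (6 + 3) / 4 = 2 tokens for the raw text.
    assert_eq!(counter.count("Hello!"), 2);

    // One message adds the 4-token structural overhead: 2 + 4 = 6.
    assert_eq!(counter.count_messages(&[ChatMessage::user("Hello!")]), 6);
}
```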
|
||||
|
||||
// ============================================================================
|
||||
// ClaudeTokenCounter
|
||||
// ============================================================================
|
||||
|
||||
/// Token counter optimized for Anthropic Claude models
|
||||
///
|
||||
/// Claude models have specific tokenization characteristics and overhead.
|
||||
/// This counter adjusts the estimates accordingly.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use llm_core::tokens::{TokenCounter, ClaudeTokenCounter};
|
||||
/// use llm_core::ChatMessage;
|
||||
///
|
||||
/// let counter = ClaudeTokenCounter::new();
|
||||
/// let messages = vec![
|
||||
/// ChatMessage::system("You are a helpful assistant."),
|
||||
/// ChatMessage::user("Hello!"),
|
||||
/// ];
|
||||
/// let total = counter.count_messages(&messages);
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClaudeTokenCounter {
|
||||
max_context: usize,
|
||||
}
|
||||
|
||||
impl ClaudeTokenCounter {
|
||||
/// Create a new Claude token counter with default 200k context
|
||||
///
|
||||
/// This is suitable for Claude 3.5 Sonnet, Claude 4 Sonnet, and Claude 4 Opus.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
max_context: 200_000,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a Claude counter with a custom context window
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `max_context` - Maximum context window size
|
||||
pub fn with_context(max_context: usize) -> Self {
|
||||
Self { max_context }
|
||||
}
|
||||
|
||||
/// Create a counter for Claude 3 Haiku (200k context)
|
||||
pub fn haiku() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
|
||||
/// Create a counter for Claude 3.5 Sonnet (200k context)
|
||||
pub fn sonnet() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
|
||||
/// Create a counter for Claude 4 Opus (200k context)
|
||||
pub fn opus() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ClaudeTokenCounter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenCounter for ClaudeTokenCounter {
|
||||
fn count(&self, text: &str) -> usize {
|
||||
// Claude's tokenization is similar to the 4 chars/token heuristic
|
||||
// but tends to be slightly more efficient with structured content
|
||||
(text.len() + 3) / 4
|
||||
}
|
||||
|
||||
fn count_messages(&self, messages: &[ChatMessage]) -> usize {
|
||||
let mut total = 0;
|
||||
|
||||
// Claude has specific message formatting overhead
|
||||
const MESSAGE_OVERHEAD: usize = 5;
|
||||
const SYSTEM_MESSAGE_OVERHEAD: usize = 3;
|
||||
|
||||
for msg in messages {
|
||||
// Different overhead for system vs other messages
|
||||
let overhead = if matches!(msg.role, crate::Role::System) {
|
||||
SYSTEM_MESSAGE_OVERHEAD
|
||||
} else {
|
||||
MESSAGE_OVERHEAD
|
||||
};
|
||||
|
||||
total += overhead;
|
||||
|
||||
// Count content
|
||||
if let Some(content) = &msg.content {
|
||||
total += self.count(content);
|
||||
}
|
||||
|
||||
// Count tool calls
|
||||
if let Some(tool_calls) = &msg.tool_calls {
|
||||
// Claude's tool call format has additional overhead
|
||||
const TOOL_CALL_OVERHEAD: usize = 10;
|
||||
|
||||
for tc in tool_calls {
|
||||
total += TOOL_CALL_OVERHEAD;
|
||||
total += self.count(&tc.id);
|
||||
total += self.count(&tc.function.name);
|
||||
|
||||
// Arguments with JSON structure overhead
|
||||
let args_str = tc.function.arguments.to_string();
|
||||
total += (self.count(&args_str) * 12) / 10;
|
||||
}
|
||||
}
|
||||
|
||||
// Tool result overhead
|
||||
if msg.tool_call_id.is_some() {
|
||||
const TOOL_RESULT_OVERHEAD: usize = 8;
|
||||
total += TOOL_RESULT_OVERHEAD;
|
||||
|
||||
if let Some(tool_call_id) = &msg.tool_call_id {
|
||||
total += self.count(tool_call_id);
|
||||
}
|
||||
|
||||
if let Some(name) = &msg.name {
|
||||
total += self.count(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
total
|
||||
}
|
||||
|
||||
fn max_context(&self) -> usize {
|
||||
self.max_context
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ContextWindow
|
||||
// ============================================================================
|
||||
|
||||
/// Manages context window tracking for a conversation
|
||||
///
|
||||
/// Helps monitor token usage and determine when context limits are approaching.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use llm_core::tokens::{ContextWindow, TokenCounter, SimpleTokenCounter};
|
||||
/// use llm_core::ChatMessage;
|
||||
///
|
||||
/// let counter = SimpleTokenCounter::new(8192);
|
||||
/// let mut window = ContextWindow::new(counter.max_context());
|
||||
///
|
||||
/// let messages = vec![
|
||||
/// ChatMessage::user("Hello!"),
|
||||
/// ChatMessage::assistant("Hi there!"),
|
||||
/// ];
|
||||
///
|
||||
/// let tokens = counter.count_messages(&messages);
|
||||
/// window.add_tokens(tokens);
|
||||
///
|
||||
/// println!("Used: {} tokens", window.used());
|
||||
/// println!("Remaining: {} tokens", window.remaining());
|
||||
/// println!("Usage: {:.1}%", window.usage_percent() * 100.0);
|
||||
///
|
||||
/// if window.is_near_limit(0.8) {
|
||||
/// println!("Warning: Context is 80% full!");
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ContextWindow {
|
||||
/// Number of tokens currently used
|
||||
used: usize,
|
||||
/// Maximum number of tokens allowed
|
||||
max: usize,
|
||||
}
|
||||
|
||||
impl ContextWindow {
|
||||
/// Create a new context window tracker
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `max` - Maximum context window size in tokens
|
||||
pub fn new(max: usize) -> Self {
|
||||
Self { used: 0, max }
|
||||
}
|
||||
|
||||
/// Create a context window with initial usage
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `max` - Maximum context window size
|
||||
/// * `used` - Initial number of tokens used
|
||||
pub fn with_usage(max: usize, used: usize) -> Self {
|
||||
Self { used, max }
|
||||
}
|
||||
|
||||
/// Get the number of tokens currently used
|
||||
pub fn used(&self) -> usize {
|
||||
self.used
|
||||
}
|
||||
|
||||
/// Get the maximum number of tokens
|
||||
pub fn max(&self) -> usize {
|
||||
self.max
|
||||
}
|
||||
|
||||
/// Get the number of remaining tokens
|
||||
pub fn remaining(&self) -> usize {
|
||||
self.max.saturating_sub(self.used)
|
||||
}
|
||||
|
||||
/// Get the usage as a percentage (0.0 to 1.0)
|
||||
///
|
||||
/// Returns the fraction of the context window that is currently used.
|
||||
pub fn usage_percent(&self) -> f32 {
|
||||
if self.max == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
self.used as f32 / self.max as f32
|
||||
}
|
||||
|
||||
/// Check if usage is near the limit
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `threshold` - Threshold as a fraction (0.0 to 1.0). For example,
|
||||
/// 0.8 means "is usage > 80%?"
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if the current usage exceeds the threshold percentage
|
||||
pub fn is_near_limit(&self, threshold: f32) -> bool {
|
||||
self.usage_percent() > threshold
|
||||
}
|
||||
|
||||
/// Add tokens to the usage count
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `tokens` - Number of tokens to add
|
||||
pub fn add_tokens(&mut self, tokens: usize) {
|
||||
self.used = self.used.saturating_add(tokens);
|
||||
}
|
||||
|
||||
/// Set the current usage
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `used` - Number of tokens currently used
|
||||
pub fn set_used(&mut self, used: usize) {
|
||||
self.used = used;
|
||||
}
|
||||
|
||||
/// Reset the usage counter to zero
|
||||
pub fn reset(&mut self) {
|
||||
self.used = 0;
|
||||
}
|
||||
|
||||
/// Check if there's enough room for additional tokens
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `tokens` - Number of tokens needed
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if adding these tokens would stay within the limit
|
||||
pub fn has_room_for(&self, tokens: usize) -> bool {
|
||||
self.used.saturating_add(tokens) <= self.max
|
||||
}
|
||||
|
||||
/// Get a visual progress bar representation
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `width` - Width of the progress bar in characters
|
||||
///
|
||||
/// # Returns
|
||||
/// A string with a simple text-based progress bar
|
||||
pub fn progress_bar(&self, width: usize) -> String {
|
||||
if width == 0 {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
let percent = self.usage_percent();
|
||||
let filled = ((percent * width as f32) as usize).min(width);
|
||||
let empty = width - filled;
|
||||
|
||||
format!(
|
||||
"[{}{}] {:.1}%",
|
||||
"=".repeat(filled),
|
||||
" ".repeat(empty),
|
||||
percent * 100.0
|
||||
)
|
||||
}
|
||||
}
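One pattern these helpers support is reserving room for the model's reply before sending a request: check `has_room_for` with the prompt tokens plus the `max_tokens` you intend to request, and compact first if it fails. A sketch; the 1024-token reply budget and the 7400-token starting usage are arbitrary:

```rust
use llm_core::{ChatMessage, ContextWindow, SimpleTokenCounter, TokenCounter};

/// Returns true if the pending messages plus a reply budget still fit in the window.
fn fits_with_reply(
    window: &ContextWindow,
    counter: &dyn TokenCounter,
    pending: &[ChatMessage],
    reply_budget: usize,
) -> bool {
    let needed = counter.count_messages(pending) + reply_budget;
    window.has_room_for(needed)
}

fn main() {
    let counter = SimpleTokenCounter::default_8k();
    // A window that is already nearly full.
    let window = ContextWindow::with_usage(counter.max_context(), 7400);
    let pending = vec![ChatMessage::user("One more question...")];

    if !fits_with_reply(&window, &counter, &pending, 1024) {
        println!("Compact or start a new context before sending");
    }
}
```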
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{ChatMessage, FunctionCall, ToolCall};
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_simple_counter_basic() {
|
||||
let counter = SimpleTokenCounter::new(8192);
|
||||
|
||||
// Empty string
|
||||
assert_eq!(counter.count(""), 0);
|
||||
|
||||
// Short string (~4 chars/token)
|
||||
let text = "Hello, world!"; // 13 chars -> ~4 tokens
|
||||
let count = counter.count(text);
|
||||
assert!(count >= 3 && count <= 5);
|
||||
|
||||
// Longer text
|
||||
let text = "The quick brown fox jumps over the lazy dog"; // 44 chars -> ~11 tokens
|
||||
let count = counter.count(text);
|
||||
assert!(count >= 10 && count <= 13);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_counter_messages() {
|
||||
let counter = SimpleTokenCounter::new(8192);
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::user("Hello!"),
|
||||
ChatMessage::assistant("Hi there! How can I help you today?"),
|
||||
];
|
||||
|
||||
let total = counter.count_messages(&messages);
|
||||
|
||||
// Should be more than just the text due to overhead
|
||||
let text_only = counter.count("Hello!") + counter.count("Hi there! How can I help you today?");
|
||||
assert!(total > text_only);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_counter_with_tool_calls() {
|
||||
let counter = SimpleTokenCounter::new(8192);
|
||||
|
||||
let tool_call = ToolCall {
|
||||
id: "call_123".to_string(),
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: "read_file".to_string(),
|
||||
arguments: json!({"path": "/etc/hosts"}),
|
||||
},
|
||||
};
|
||||
|
||||
let messages = vec![ChatMessage::assistant_tool_calls(vec![tool_call])];
|
||||
|
||||
let total = counter.count_messages(&messages);
|
||||
assert!(total > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_claude_counter() {
|
||||
let counter = ClaudeTokenCounter::new();
|
||||
|
||||
assert_eq!(counter.max_context(), 200_000);
|
||||
|
||||
let text = "Hello, Claude!";
|
||||
let count = counter.count(text);
|
||||
assert!(count > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_claude_counter_system_message() {
|
||||
let counter = ClaudeTokenCounter::new();
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::system("You are a helpful assistant."),
|
||||
ChatMessage::user("Hello!"),
|
||||
];
|
||||
|
||||
let total = counter.count_messages(&messages);
|
||||
assert!(total > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_context_window() {
|
||||
let mut window = ContextWindow::new(1000);
|
||||
|
||||
assert_eq!(window.used(), 0);
|
||||
assert_eq!(window.max(), 1000);
|
||||
assert_eq!(window.remaining(), 1000);
|
||||
assert_eq!(window.usage_percent(), 0.0);
|
||||
|
||||
window.add_tokens(200);
|
||||
assert_eq!(window.used(), 200);
|
||||
assert_eq!(window.remaining(), 800);
|
||||
assert_eq!(window.usage_percent(), 0.2);
|
||||
|
||||
window.add_tokens(600);
|
||||
assert_eq!(window.used(), 800);
|
||||
assert!(window.is_near_limit(0.7));
|
||||
assert!(!window.is_near_limit(0.9));
|
||||
|
||||
assert!(window.has_room_for(200));
|
||||
assert!(!window.has_room_for(300));
|
||||
|
||||
window.reset();
|
||||
assert_eq!(window.used(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_context_window_progress_bar() {
|
||||
let mut window = ContextWindow::new(100);
|
||||
|
||||
window.add_tokens(50);
|
||||
let bar = window.progress_bar(10);
|
||||
assert!(bar.contains("====="));
|
||||
assert!(bar.contains("50.0%"));
|
||||
|
||||
window.add_tokens(40);
|
||||
let bar = window.progress_bar(10);
|
||||
assert!(bar.contains("========="));
|
||||
assert!(bar.contains("90.0%"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_context_window_saturation() {
|
||||
let mut window = ContextWindow::new(100);
|
||||
|
||||
// Exceeding max must not panic or overflow; used() keeps counting while remaining() saturates at zero
|
||||
window.add_tokens(150);
|
||||
assert_eq!(window.used(), 150);
|
||||
assert_eq!(window.remaining(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_counter_constructors() {
|
||||
let counter1 = SimpleTokenCounter::default_8k();
|
||||
assert_eq!(counter1.max_context(), 8192);
|
||||
|
||||
let counter2 = SimpleTokenCounter::with_32k();
|
||||
assert_eq!(counter2.max_context(), 32768);
|
||||
|
||||
let counter3 = SimpleTokenCounter::with_128k();
|
||||
assert_eq!(counter3.max_context(), 131072);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_claude_counter_variants() {
|
||||
let haiku = ClaudeTokenCounter::haiku();
|
||||
assert_eq!(haiku.max_context(), 200_000);
|
||||
|
||||
let sonnet = ClaudeTokenCounter::sonnet();
|
||||
assert_eq!(sonnet.max_context(), 200_000);
|
||||
|
||||
let opus = ClaudeTokenCounter::opus();
|
||||
assert_eq!(opus.max_context(), 200_000);
|
||||
|
||||
let custom = ClaudeTokenCounter::with_context(100_000);
|
||||
assert_eq!(custom.max_context(), 100_000);
|
||||
}
|
||||
}
|
||||
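A short sketch of how a caller might combine the two pieces exercised above: estimate a message batch with the counter, then book that estimate against a window before sending a request. This assumes the same items are in scope as in the tests (the enclosing module and exact signatures are not shown in this hunk), so it is illustrative only.

```rust
#[test]
fn sketch_counter_drives_context_window() {
    // Hypothetical pairing of the types tested above (types assumed compatible;
    // the hunk does not show the method signatures).
    let counter = SimpleTokenCounter::new(8192);
    let mut window = ContextWindow::new(8192);

    let messages = vec![
        ChatMessage::system("You are a helpful assistant."),
        ChatMessage::user("Summarize the README."),
    ];

    // Reserve the estimated budget before issuing the request.
    window.add_tokens(counter.count_messages(&messages));
    assert!(window.remaining() < window.max());
    assert!(!window.is_near_limit(0.9));
}
```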
22
crates/llm/ollama/.gitignore
vendored
Normal file
@@ -0,0 +1,22 @@
/target
### Rust template
# Generated by Cargo
# will have compiled files and executables
debug/
target/

# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries
# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html
Cargo.lock

# These are backup files generated by rustfmt
**/*.rs.bk

# MSVC Windows builds of rustc generate these, which store debugging information
*.pdb

### rust-analyzer template
# Can be generated by other build systems other than cargo (ex: bazelbuild/rust_rules)
rust-project.json

18
crates/llm/ollama/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
[package]
name = "llm-ollama"
version = "0.1.0"
edition.workspace = true
license.workspace = true
rust-version.workspace = true

[dependencies]
llm-core = { path = "../core" }
reqwest = { version = "0.12", features = ["json", "stream"] }
tokio = { version = "1.39", features = ["rt-multi-thread", "macros"] }
futures = "0.3"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
thiserror = "1"
bytes = "1"
tokio-stream = "0.1.17"
async-trait = "0.1"
329
crates/llm/ollama/src/client.rs
Normal file
@@ -0,0 +1,329 @@
use crate::types::{ChatMessage, ChatResponseChunk, Tool};
use futures::{Stream, StreamExt, TryStreamExt};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use async_trait::async_trait;
use llm_core::{
    LlmProvider, ProviderInfo, LlmError, ChatOptions, ChunkStream,
    ProviderStatus, AccountInfo, UsageStats, ModelInfo,
};

#[derive(Debug, Clone)]
pub struct OllamaClient {
    http: Client,
    base_url: String,        // e.g. "http://localhost:11434"
    api_key: Option<String>, // For Ollama Cloud authentication
    current_model: String,   // Default model for this client
}

#[derive(Debug, Clone, Default)]
pub struct OllamaOptions {
    pub model: String,
    pub stream: bool,
}

#[derive(Error, Debug)]
pub enum OllamaError {
    #[error("http: {0}")]
    Http(#[from] reqwest::Error),
    #[error("json: {0}")]
    Json(#[from] serde_json::Error),
    #[error("protocol: {0}")]
    Protocol(String),
}

// Convert OllamaError to LlmError
impl From<OllamaError> for LlmError {
    fn from(err: OllamaError) -> Self {
        match err {
            OllamaError::Http(e) => LlmError::Http(e.to_string()),
            OllamaError::Json(e) => LlmError::Json(e.to_string()),
            OllamaError::Protocol(msg) => LlmError::Provider(msg),
        }
    }
}

impl OllamaClient {
    pub fn new(base_url: impl Into<String>) -> Self {
        Self {
            http: Client::new(),
            base_url: base_url.into().trim_end_matches('/').to_string(),
            api_key: None,
            current_model: "qwen3:8b".to_string(),
        }
    }

    pub fn with_api_key(mut self, api_key: impl Into<String>) -> Self {
        self.api_key = Some(api_key.into());
        self
    }

    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.current_model = model.into();
        self
    }

    pub fn with_cloud() -> Self {
        // Same API, different base
        Self::new("https://ollama.com")
    }

    pub async fn chat_stream_raw(
        &self,
        messages: &[ChatMessage],
        opts: &OllamaOptions,
        tools: Option<&[Tool]>,
    ) -> Result<impl Stream<Item = Result<ChatResponseChunk, OllamaError>>, OllamaError> {
        #[derive(Serialize)]
        struct Body<'a> {
            model: &'a str,
            messages: &'a [ChatMessage],
            stream: bool,
            #[serde(skip_serializing_if = "Option::is_none")]
            tools: Option<&'a [Tool]>,
        }
        let url = format!("{}/api/chat", self.base_url);
        let body = Body { model: &opts.model, messages, stream: true, tools };
        let mut req = self.http.post(url).json(&body);

        // Add Authorization header if API key is present
        if let Some(ref key) = self.api_key {
            req = req.header("Authorization", format!("Bearer {}", key));
        }

        let resp = req.send().await?;
        let bytes_stream = resp.bytes_stream();

        // NDJSON parser: split by '\n', parse each as JSON and stream the results
        let out = bytes_stream
            .map_err(OllamaError::Http)
            .map_ok(|bytes| {
                // Convert the chunk to a UTF‑8 string and own it
                let txt = String::from_utf8_lossy(&bytes).into_owned();
                // Parse each non‑empty line into a ChatResponseChunk
                let results: Vec<Result<ChatResponseChunk, OllamaError>> = txt
                    .lines()
                    .filter_map(|line| {
                        let trimmed = line.trim();
                        if trimmed.is_empty() {
                            None
                        } else {
                            Some(
                                serde_json::from_str::<ChatResponseChunk>(trimmed)
                                    .map_err(OllamaError::Json),
                            )
                        }
                    })
                    .collect();
                futures::stream::iter(results)
            })
            .try_flatten(); // Stream<Item = Result<ChatResponseChunk, OllamaError>>
        Ok(out)
    }
}

// ============================================================================
// LlmProvider Trait Implementation
// ============================================================================

#[async_trait]
impl LlmProvider for OllamaClient {
    fn name(&self) -> &str {
        "ollama"
    }

    fn model(&self) -> &str {
        &self.current_model
    }

    async fn chat_stream(
        &self,
        messages: &[llm_core::ChatMessage],
        options: &ChatOptions,
        tools: Option<&[llm_core::Tool]>,
    ) -> Result<ChunkStream, LlmError> {
        // Convert llm_core messages to Ollama messages
        let ollama_messages: Vec<ChatMessage> = messages.iter().map(|m| m.into()).collect();

        // Convert llm_core tools to Ollama tools if present
        let ollama_tools: Option<Vec<Tool>> = tools.map(|tools| {
            tools.iter().map(|t| Tool {
                tool_type: t.tool_type.clone(),
                function: crate::types::ToolFunction {
                    name: t.function.name.clone(),
                    description: t.function.description.clone(),
                    parameters: crate::types::ToolParameters {
                        param_type: t.function.parameters.param_type.clone(),
                        properties: t.function.parameters.properties.clone(),
                        required: t.function.parameters.required.clone(),
                    },
                },
            }).collect()
        });

        let opts = OllamaOptions {
            model: options.model.clone(),
            stream: true,
        };

        // Make the request and build the body inline to avoid lifetime issues
        #[derive(Serialize)]
        struct Body<'a> {
            model: &'a str,
            messages: &'a [ChatMessage],
            stream: bool,
            #[serde(skip_serializing_if = "Option::is_none")]
            tools: Option<&'a [Tool]>,
        }

        let url = format!("{}/api/chat", self.base_url);
        let body = Body {
            model: &opts.model,
            messages: &ollama_messages,
            stream: true,
            tools: ollama_tools.as_deref(),
        };

        let mut req = self.http.post(url).json(&body);

        // Add Authorization header if API key is present
        if let Some(ref key) = self.api_key {
            req = req.header("Authorization", format!("Bearer {}", key));
        }

        let resp = req.send().await
            .map_err(|e| LlmError::Http(e.to_string()))?;
        let bytes_stream = resp.bytes_stream();

        // NDJSON parser: split by '\n', parse each as JSON and stream the results
        let converted_stream = bytes_stream
            .map(|result| {
                result.map_err(|e| LlmError::Http(e.to_string()))
            })
            .map_ok(|bytes| {
                // Convert the chunk to a UTF-8 string and own it
                let txt = String::from_utf8_lossy(&bytes).into_owned();
                // Parse each non-empty line into a ChatResponseChunk
                let results: Vec<Result<llm_core::StreamChunk, LlmError>> = txt
                    .lines()
                    .filter_map(|line| {
                        let trimmed = line.trim();
                        if trimmed.is_empty() {
                            None
                        } else {
                            Some(
                                serde_json::from_str::<ChatResponseChunk>(trimmed)
                                    .map(llm_core::StreamChunk::from)
                                    .map_err(|e| LlmError::Json(e.to_string())),
                            )
                        }
                    })
                    .collect();
                futures::stream::iter(results)
            })
            .try_flatten();

        Ok(Box::pin(converted_stream))
    }
}

// ============================================================================
// ProviderInfo Trait Implementation
// ============================================================================

#[derive(Debug, Clone, Deserialize)]
struct OllamaModelList {
    models: Vec<OllamaModel>,
}

#[derive(Debug, Clone, Deserialize)]
struct OllamaModel {
    name: String,
    #[serde(default)]
    modified_at: Option<String>,
    #[serde(default)]
    size: Option<u64>,
    #[serde(default)]
    digest: Option<String>,
    #[serde(default)]
    details: Option<OllamaModelDetails>,
}

#[derive(Debug, Clone, Deserialize)]
struct OllamaModelDetails {
    #[serde(default)]
    format: Option<String>,
    #[serde(default)]
    family: Option<String>,
    #[serde(default)]
    parameter_size: Option<String>,
}

#[async_trait]
impl ProviderInfo for OllamaClient {
    async fn status(&self) -> Result<ProviderStatus, LlmError> {
        // Try to ping the Ollama server
        let url = format!("{}/api/tags", self.base_url);
        let reachable = self.http.get(&url).send().await.is_ok();

        Ok(ProviderStatus {
            provider: "ollama".to_string(),
            authenticated: self.api_key.is_some(),
            account: None, // Ollama is local, no account info
            model: self.current_model.clone(),
            endpoint: self.base_url.clone(),
            reachable,
            message: if reachable {
                Some("Connected to Ollama".to_string())
            } else {
                Some("Cannot reach Ollama server".to_string())
            },
        })
    }

    async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError> {
        // Ollama is a local service, no account info
        Ok(None)
    }

    async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError> {
        // Ollama doesn't track usage statistics
        Ok(None)
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
        let url = format!("{}/api/tags", self.base_url);
        let mut req = self.http.get(&url);

        // Add Authorization header if API key is present
        if let Some(ref key) = self.api_key {
            req = req.header("Authorization", format!("Bearer {}", key));
        }

        let resp = req.send().await
            .map_err(|e| LlmError::Http(e.to_string()))?;

        let model_list: OllamaModelList = resp.json().await
            .map_err(|e| LlmError::Json(e.to_string()))?;

        // Convert Ollama models to ModelInfo
        let models = model_list.models.into_iter().map(|m| {
            ModelInfo {
                id: m.name.clone(),
                display_name: Some(m.name.clone()),
                description: m.details.as_ref()
                    .and_then(|d| d.family.as_ref())
                    .map(|f| format!("{} model", f)),
                context_window: None, // Ollama doesn't provide this in list
                max_output_tokens: None,
                supports_tools: true, // Most Ollama models support tools
                supports_vision: false, // Would need to check model capabilities
                input_price_per_mtok: None, // Local models are free
                output_price_per_mtok: None,
            }
        }).collect();

        Ok(models)
    }
}
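The streaming methods above rely on Ollama's NDJSON framing: every line of the response body is a self-contained JSON object that deserializes on its own, which is why the client can flatten per-chunk parse results into one stream. A minimal sketch of a single line; the JSON payload is illustrative, not captured from a live server.

```rust
use llm_ollama::ChatResponseChunk;

fn main() -> Result<(), serde_json::Error> {
    // One NDJSON line as Ollama would stream it (illustrative payload).
    let line = r#"{"model":"qwen3:8b","created_at":"2024-01-01T00:00:00Z","message":{"role":"assistant","content":"Hel"},"done":false}"#;

    // Each line parses independently into a ChatResponseChunk.
    let chunk: ChatResponseChunk = serde_json::from_str(line.trim())?;
    assert_eq!(chunk.done, Some(false));
    assert_eq!(chunk.message.unwrap().content.as_deref(), Some("Hel"));
    Ok(())
}
```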
13
crates/llm/ollama/src/lib.rs
Normal file
@@ -0,0 +1,13 @@
pub mod client;
pub mod types;

pub use client::{OllamaClient, OllamaOptions, OllamaError};
pub use types::{ChatMessage, ChatResponseChunk, Tool, ToolCall, ToolFunction, ToolParameters, FunctionCall};

// Re-export llm-core traits and types for convenience
pub use llm_core::{
    LlmProvider, ProviderInfo, LlmError,
    ChatOptions, StreamChunk, ToolCallDelta, Usage,
    ProviderStatus, AccountInfo, UsageStats, ModelInfo,
    Role,
};
130
crates/llm/ollama/src/types.rs
Normal file
@@ -0,0 +1,130 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
use llm_core::{StreamChunk, ToolCallDelta};

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    pub role: String, // "user" | "assistant" | "system" | "tool"
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    pub call_type: Option<String>, // "function"
    pub function: FunctionCall,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCall {
    pub name: String,
    pub arguments: Value, // JSON object with arguments
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    #[serde(rename = "type")]
    pub tool_type: String, // "function"
    pub function: ToolFunction,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunction {
    pub name: String,
    pub description: String,
    pub parameters: ToolParameters,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameters {
    #[serde(rename = "type")]
    pub param_type: String, // "object"
    pub properties: Value,
    pub required: Vec<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ChatResponseChunk {
    pub model: Option<String>,
    pub created_at: Option<String>,
    pub message: Option<ChunkMessage>,
    pub done: Option<bool>,
    pub total_duration: Option<u64>,
}

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ChunkMessage {
    pub role: Option<String>,
    pub content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
}

// ============================================================================
// Conversions to/from llm-core types
// ============================================================================

/// Convert from llm_core::ChatMessage to Ollama's ChatMessage
impl From<&llm_core::ChatMessage> for ChatMessage {
    fn from(msg: &llm_core::ChatMessage) -> Self {
        let role = msg.role.as_str().to_string();

        // Convert tool_calls if present
        let tool_calls = msg.tool_calls.as_ref().map(|calls| {
            calls.iter().map(|tc| ToolCall {
                id: Some(tc.id.clone()),
                call_type: Some(tc.call_type.clone()),
                function: FunctionCall {
                    name: tc.function.name.clone(),
                    arguments: tc.function.arguments.clone(),
                },
            }).collect()
        });

        ChatMessage {
            role,
            content: msg.content.clone(),
            tool_calls,
        }
    }
}

/// Convert from Ollama's ChatResponseChunk to llm_core::StreamChunk
impl From<ChatResponseChunk> for StreamChunk {
    fn from(chunk: ChatResponseChunk) -> Self {
        let done = chunk.done.unwrap_or(false);
        let content = chunk.message.as_ref().and_then(|m| m.content.clone());

        // Convert tool calls to deltas
        let tool_calls = chunk.message.as_ref().and_then(|m| {
            m.tool_calls.as_ref().map(|calls| {
                calls.iter().enumerate().map(|(index, tc)| {
                    // Serialize arguments back to JSON string for delta
                    let arguments_delta = serde_json::to_string(&tc.function.arguments).ok();

                    ToolCallDelta {
                        index,
                        id: tc.id.clone(),
                        function_name: Some(tc.function.name.clone()),
                        arguments_delta,
                    }
                }).collect()
            })
        });

        // Ollama doesn't provide per-chunk usage stats, only in final chunk
        let usage = None;

        StreamChunk {
            content,
            tool_calls,
            done,
            usage,
        }
    }
}
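A small sketch of the chunk-to-`StreamChunk` conversion above, using a hand-built chunk that carries a tool call; all values are illustrative.

```rust
use llm_ollama::types::{ChatResponseChunk, ChunkMessage, FunctionCall, ToolCall};
use llm_ollama::StreamChunk;
use serde_json::json;

fn main() {
    // Hand-built chunk standing in for one parsed NDJSON line with a tool call.
    let chunk = ChatResponseChunk {
        model: Some("qwen3:8b".to_string()),
        created_at: None,
        message: Some(ChunkMessage {
            role: Some("assistant".to_string()),
            content: None,
            tool_calls: Some(vec![ToolCall {
                id: None,
                call_type: Some("function".to_string()),
                function: FunctionCall {
                    name: "read_file".to_string(),
                    arguments: json!({"path": "/etc/hosts"}),
                },
            }]),
        }),
        done: Some(true),
        total_duration: None,
    };

    // The From impl turns each tool call into an indexed delta with the
    // arguments re-serialized as a JSON string.
    let stream_chunk = StreamChunk::from(chunk);
    let deltas = stream_chunk.tool_calls.expect("one delta");
    assert_eq!(deltas[0].function_name.as_deref(), Some("read_file"));
    assert!(stream_chunk.done);
}
```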
12
crates/llm/ollama/tests/ndjson.rs
Normal file
@@ -0,0 +1,12 @@
use llm_ollama::{OllamaClient, OllamaOptions};

// Spinning up a tiny local server to stub NDJSON is overkill for M0, and
// mocking reqwest to exercise the line parser indirectly is complex.
// We just smoke-test that the client types construct and leave end-to-end coverage to the CLI tests.

#[tokio::test]
async fn client_compiles_smoke() {
    let _ = OllamaClient::new("http://localhost:11434");
    let _ = OllamaClient::with_cloud();
    let _ = OllamaOptions { model: "qwen2.5".into(), stream: true };
}
18
crates/llm/openai/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
[package]
name = "llm-openai"
version = "0.1.0"
edition.workspace = true
license.workspace = true
description = "OpenAI GPT API client for Owlen"

[dependencies]
llm-core = { path = "../core" }
async-trait = "0.1"
futures = "0.3"
reqwest = { version = "0.12", features = ["json", "stream"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1", features = ["sync", "time", "io-util"] }
tokio-stream = { version = "0.1", default-features = false, features = ["io-util"] }
tokio-util = { version = "0.7", features = ["codec", "io"] }
tracing = "0.1"
285
crates/llm/openai/src/auth.rs
Normal file
@@ -0,0 +1,285 @@
//! OpenAI OAuth Authentication
//!
//! Implements device code flow for authenticating with OpenAI without API keys.

use llm_core::{AuthMethod, DeviceAuthResult, DeviceCodeResponse, LlmError, OAuthProvider};
use reqwest::Client;
use serde::{Deserialize, Serialize};

/// OAuth client for OpenAI device flow
pub struct OpenAIAuth {
    http: Client,
    client_id: String,
}

// OpenAI OAuth endpoints
const AUTH_BASE_URL: &str = "https://auth.openai.com";
const DEVICE_CODE_ENDPOINT: &str = "/oauth/device/code";
const TOKEN_ENDPOINT: &str = "/oauth/token";

// Default client ID for Owlen CLI
const DEFAULT_CLIENT_ID: &str = "owlen-cli";

impl OpenAIAuth {
    /// Create a new OAuth client with the default CLI client ID
    pub fn new() -> Self {
        Self {
            http: Client::new(),
            client_id: DEFAULT_CLIENT_ID.to_string(),
        }
    }

    /// Create with a custom client ID
    pub fn with_client_id(client_id: impl Into<String>) -> Self {
        Self {
            http: Client::new(),
            client_id: client_id.into(),
        }
    }
}

impl Default for OpenAIAuth {
    fn default() -> Self {
        Self::new()
    }
}

#[derive(Debug, Serialize)]
struct DeviceCodeRequest<'a> {
    client_id: &'a str,
    scope: &'a str,
}

#[derive(Debug, Deserialize)]
struct DeviceCodeApiResponse {
    device_code: String,
    user_code: String,
    verification_uri: String,
    verification_uri_complete: Option<String>,
    expires_in: u64,
    interval: u64,
}

#[derive(Debug, Serialize)]
struct TokenRequest<'a> {
    client_id: &'a str,
    device_code: &'a str,
    grant_type: &'a str,
}

#[derive(Debug, Deserialize)]
struct TokenApiResponse {
    access_token: String,
    #[allow(dead_code)]
    token_type: String,
    expires_in: Option<u64>,
    refresh_token: Option<String>,
}

#[derive(Debug, Deserialize)]
struct TokenErrorResponse {
    error: String,
    error_description: Option<String>,
}

#[async_trait::async_trait]
impl OAuthProvider for OpenAIAuth {
    async fn start_device_auth(&self) -> Result<DeviceCodeResponse, LlmError> {
        let url = format!("{}{}", AUTH_BASE_URL, DEVICE_CODE_ENDPOINT);

        let request = DeviceCodeRequest {
            client_id: &self.client_id,
            scope: "api.read api.write",
        };

        let response = self
            .http
            .post(&url)
            .json(&request)
            .send()
            .await
            .map_err(|e| LlmError::Http(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(LlmError::Auth(format!(
                "Device code request failed ({}): {}",
                status, text
            )));
        }

        let api_response: DeviceCodeApiResponse = response
            .json()
            .await
            .map_err(|e| LlmError::Json(e.to_string()))?;

        Ok(DeviceCodeResponse {
            device_code: api_response.device_code,
            user_code: api_response.user_code,
            verification_uri: api_response.verification_uri,
            verification_uri_complete: api_response.verification_uri_complete,
            expires_in: api_response.expires_in,
            interval: api_response.interval,
        })
    }

    async fn poll_device_auth(&self, device_code: &str) -> Result<DeviceAuthResult, LlmError> {
        let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);

        let request = TokenRequest {
            client_id: &self.client_id,
            device_code,
            grant_type: "urn:ietf:params:oauth:grant-type:device_code",
        };

        let response = self
            .http
            .post(&url)
            .json(&request)
            .send()
            .await
            .map_err(|e| LlmError::Http(e.to_string()))?;

        if response.status().is_success() {
            let token_response: TokenApiResponse = response
                .json()
                .await
                .map_err(|e| LlmError::Json(e.to_string()))?;

            return Ok(DeviceAuthResult::Success {
                access_token: token_response.access_token,
                refresh_token: token_response.refresh_token,
                expires_in: token_response.expires_in,
            });
        }

        // Parse error response
        let error_response: TokenErrorResponse = response
            .json()
            .await
            .map_err(|e| LlmError::Json(e.to_string()))?;

        match error_response.error.as_str() {
            "authorization_pending" => Ok(DeviceAuthResult::Pending),
            "slow_down" => Ok(DeviceAuthResult::Pending),
            "access_denied" => Ok(DeviceAuthResult::Denied),
            "expired_token" => Ok(DeviceAuthResult::Expired),
            _ => Err(LlmError::Auth(format!(
                "Token request failed: {} - {}",
                error_response.error,
                error_response.error_description.unwrap_or_default()
            ))),
        }
    }

    async fn refresh_token(&self, refresh_token: &str) -> Result<AuthMethod, LlmError> {
        let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);

        #[derive(Serialize)]
        struct RefreshRequest<'a> {
            client_id: &'a str,
            refresh_token: &'a str,
            grant_type: &'a str,
        }

        let request = RefreshRequest {
            client_id: &self.client_id,
            refresh_token,
            grant_type: "refresh_token",
        };

        let response = self
            .http
            .post(&url)
            .json(&request)
            .send()
            .await
            .map_err(|e| LlmError::Http(e.to_string()))?;

        if !response.status().is_success() {
            let text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(LlmError::Auth(format!("Token refresh failed: {}", text)));
        }

        let token_response: TokenApiResponse = response
            .json()
            .await
            .map_err(|e| LlmError::Json(e.to_string()))?;

        let expires_at = token_response.expires_in.map(|secs| {
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_secs() + secs)
                .unwrap_or(0)
        });

        Ok(AuthMethod::OAuth {
            access_token: token_response.access_token,
            refresh_token: token_response.refresh_token,
            expires_at,
        })
    }
}

/// Helper to perform the full device auth flow with polling
pub async fn perform_device_auth<F>(
    auth: &OpenAIAuth,
    on_code: F,
) -> Result<AuthMethod, LlmError>
where
    F: FnOnce(&DeviceCodeResponse),
{
    // Start the device flow
    let device_code = auth.start_device_auth().await?;

    // Let caller display the code to user
    on_code(&device_code);

    // Poll for completion
    let poll_interval = std::time::Duration::from_secs(device_code.interval);
    let deadline =
        std::time::Instant::now() + std::time::Duration::from_secs(device_code.expires_in);

    loop {
        if std::time::Instant::now() > deadline {
            return Err(LlmError::Auth("Device code expired".to_string()));
        }

        tokio::time::sleep(poll_interval).await;

        match auth.poll_device_auth(&device_code.device_code).await? {
            DeviceAuthResult::Success {
                access_token,
                refresh_token,
                expires_in,
            } => {
                let expires_at = expires_in.map(|secs| {
                    std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .map(|d| d.as_secs() + secs)
                        .unwrap_or(0)
                });

                return Ok(AuthMethod::OAuth {
                    access_token,
                    refresh_token,
                    expires_at,
                });
            }
            DeviceAuthResult::Pending => continue,
            DeviceAuthResult::Denied => {
                return Err(LlmError::Auth("Authorization denied by user".to_string()));
            }
            DeviceAuthResult::Expired => {
                return Err(LlmError::Auth("Device code expired".to_string()));
            }
        }
    }
}
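A sketch of how a caller might drive the `perform_device_auth` helper above. The binary would need tokio's `rt-multi-thread` and `macros` features, which this crate itself does not enable; the printed wording is illustrative.

```rust
use llm_core::{AuthMethod, LlmError};
use llm_openai::{perform_device_auth, OpenAIAuth};

#[tokio::main]
async fn main() -> Result<(), LlmError> {
    let auth = OpenAIAuth::new();

    // The callback only displays the code; polling happens inside the helper.
    let method = perform_device_auth(&auth, |code| {
        println!("Visit {} and enter code {}", code.verification_uri, code.user_code);
    })
    .await?;

    if let AuthMethod::OAuth { access_token, .. } = &method {
        println!("Authenticated; token length = {}", access_token.len());
    }
    Ok(())
}
```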
561
crates/llm/openai/src/client.rs
Normal file
@@ -0,0 +1,561 @@
//! OpenAI GPT API Client
//!
//! Implements the Chat Completions API with streaming support.

use crate::types::*;
use async_trait::async_trait;
use futures::StreamExt;
use llm_core::{
    AccountInfo, AuthMethod, ChatMessage, ChatOptions, ChatResponse, ChunkStream, FunctionCall,
    LlmError, LlmProvider, ModelInfo, ProviderInfo, ProviderStatus, StreamChunk, Tool, ToolCall,
    ToolCallDelta, Usage, UsageStats,
};
use reqwest::Client;
use tokio::io::AsyncBufReadExt;
use tokio_stream::wrappers::LinesStream;
use tokio_util::io::StreamReader;

const API_BASE_URL: &str = "https://api.openai.com/v1";
const CHAT_ENDPOINT: &str = "/chat/completions";
const MODELS_ENDPOINT: &str = "/models";

/// OpenAI GPT API client
pub struct OpenAIClient {
    http: Client,
    auth: AuthMethod,
    model: String,
}

impl OpenAIClient {
    /// Create a new client with API key authentication
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            http: Client::new(),
            auth: AuthMethod::api_key(api_key),
            model: "gpt-4o".to_string(),
        }
    }

    /// Create a new client with OAuth token
    pub fn with_oauth(access_token: impl Into<String>) -> Self {
        Self {
            http: Client::new(),
            auth: AuthMethod::oauth(access_token),
            model: "gpt-4o".to_string(),
        }
    }

    /// Create a new client with full AuthMethod
    pub fn with_auth(auth: AuthMethod) -> Self {
        Self {
            http: Client::new(),
            auth,
            model: "gpt-4o".to_string(),
        }
    }

    /// Set the model to use
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Get current auth method (for token refresh)
    pub fn auth(&self) -> &AuthMethod {
        &self.auth
    }

    /// Update the auth method (after refresh)
    pub fn set_auth(&mut self, auth: AuthMethod) {
        self.auth = auth;
    }

    /// Convert messages to OpenAI format
    fn prepare_messages(messages: &[ChatMessage]) -> Vec<OpenAIMessage> {
        messages.iter().map(OpenAIMessage::from).collect()
    }

    /// Convert tools to OpenAI format
    fn prepare_tools(tools: Option<&[Tool]>) -> Option<Vec<OpenAITool>> {
        tools.map(|t| t.iter().map(OpenAITool::from).collect())
    }
}

#[async_trait]
impl LlmProvider for OpenAIClient {
    fn name(&self) -> &str {
        "openai"
    }

    fn model(&self) -> &str {
        &self.model
    }

    async fn chat_stream(
        &self,
        messages: &[ChatMessage],
        options: &ChatOptions,
        tools: Option<&[Tool]>,
    ) -> Result<ChunkStream, LlmError> {
        let url = format!("{}{}", API_BASE_URL, CHAT_ENDPOINT);

        let model = if options.model.is_empty() {
            &self.model
        } else {
            &options.model
        };

        let openai_messages = Self::prepare_messages(messages);
        let openai_tools = Self::prepare_tools(tools);

        let request = ChatCompletionRequest {
            model,
            messages: openai_messages,
            temperature: options.temperature,
            max_tokens: options.max_tokens,
            top_p: options.top_p,
            stop: options.stop.as_deref(),
            tools: openai_tools,
            tool_choice: None,
            stream: true,
        };

        let bearer = self
            .auth
            .bearer_token()
            .ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;

        let response = self
            .http
            .post(&url)
            .header("Authorization", format!("Bearer {}", bearer))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await
            .map_err(|e| LlmError::Http(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());

            if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
                return Err(LlmError::RateLimit {
                    retry_after_secs: None,
                });
            }

            // Try to parse as error response
            if let Ok(err_resp) = serde_json::from_str::<ErrorResponse>(&text) {
                return Err(LlmError::Api {
                    message: err_resp.error.message,
                    code: err_resp.error.code,
                });
            }

            return Err(LlmError::Api {
                message: text,
                code: Some(status.to_string()),
            });
        }

        // Parse SSE stream
        let byte_stream = response
            .bytes_stream()
            .map(|result| result.map_err(std::io::Error::other));

        let reader = StreamReader::new(byte_stream);
        let buf_reader = tokio::io::BufReader::new(reader);
        let lines_stream = LinesStream::new(buf_reader.lines());

        let chunk_stream = lines_stream.filter_map(|line_result| async move {
            match line_result {
                Ok(line) => parse_sse_line(&line),
                Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
            }
        });

        Ok(Box::pin(chunk_stream))
    }

    async fn chat(
        &self,
        messages: &[ChatMessage],
        options: &ChatOptions,
        tools: Option<&[Tool]>,
    ) -> Result<ChatResponse, LlmError> {
        let url = format!("{}{}", API_BASE_URL, CHAT_ENDPOINT);

        let model = if options.model.is_empty() {
            &self.model
        } else {
            &options.model
        };

        let openai_messages = Self::prepare_messages(messages);
        let openai_tools = Self::prepare_tools(tools);

        let request = ChatCompletionRequest {
            model,
            messages: openai_messages,
            temperature: options.temperature,
            max_tokens: options.max_tokens,
            top_p: options.top_p,
            stop: options.stop.as_deref(),
            tools: openai_tools,
            tool_choice: None,
            stream: false,
        };

        let bearer = self
            .auth
            .bearer_token()
            .ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;

        let response = self
            .http
            .post(&url)
            .header("Authorization", format!("Bearer {}", bearer))
            .json(&request)
            .send()
            .await
            .map_err(|e| LlmError::Http(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());

            if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
                return Err(LlmError::RateLimit {
                    retry_after_secs: None,
                });
            }

            if let Ok(err_resp) = serde_json::from_str::<ErrorResponse>(&text) {
                return Err(LlmError::Api {
                    message: err_resp.error.message,
                    code: err_resp.error.code,
                });
            }

            return Err(LlmError::Api {
                message: text,
                code: Some(status.to_string()),
            });
        }

        let api_response: ChatCompletionResponse = response
            .json()
            .await
            .map_err(|e| LlmError::Json(e.to_string()))?;

        // Extract the first choice
        let choice = api_response
            .choices
            .first()
            .ok_or_else(|| LlmError::Api {
                message: "No choices in response".to_string(),
                code: None,
            })?;

        let content = choice.message.content.clone();

        let tool_calls = choice.message.tool_calls.as_ref().map(|calls| {
            calls
                .iter()
                .map(|call| {
                    let arguments: serde_json::Value =
                        serde_json::from_str(&call.function.arguments).unwrap_or_default();

                    ToolCall {
                        id: call.id.clone(),
                        call_type: "function".to_string(),
                        function: FunctionCall {
                            name: call.function.name.clone(),
                            arguments,
                        },
                    }
                })
                .collect()
        });

        let usage = api_response.usage.map(|u| Usage {
            prompt_tokens: u.prompt_tokens,
            completion_tokens: u.completion_tokens,
            total_tokens: u.total_tokens,
        });

        Ok(ChatResponse {
            content,
            tool_calls,
            usage,
        })
    }
}

/// Parse a single SSE line into a StreamChunk
fn parse_sse_line(line: &str) -> Option<Result<StreamChunk, LlmError>> {
    let line = line.trim();

    // Skip empty lines and comments
    if line.is_empty() || line.starts_with(':') {
        return None;
    }

    // SSE format: "data: <json>"
    if let Some(data) = line.strip_prefix("data: ") {
        // OpenAI sends [DONE] to signal end
        if data == "[DONE]" {
            return Some(Ok(StreamChunk {
                content: None,
                tool_calls: None,
                done: true,
                usage: None,
            }));
        }

        // Parse the JSON chunk
        match serde_json::from_str::<ChatCompletionChunk>(data) {
            Ok(chunk) => Some(convert_chunk_to_stream_chunk(chunk)),
            Err(e) => {
                tracing::warn!("Failed to parse SSE chunk: {}", e);
                None
            }
        }
    } else {
        None
    }
}

/// Convert OpenAI chunk to our common format
fn convert_chunk_to_stream_chunk(chunk: ChatCompletionChunk) -> Result<StreamChunk, LlmError> {
    let choice = chunk.choices.first();

    if let Some(choice) = choice {
        let content = choice.delta.content.clone();

        let tool_calls = choice.delta.tool_calls.as_ref().map(|deltas| {
            deltas
                .iter()
                .map(|delta| ToolCallDelta {
                    index: delta.index,
                    id: delta.id.clone(),
                    function_name: delta.function.as_ref().and_then(|f| f.name.clone()),
                    arguments_delta: delta.function.as_ref().and_then(|f| f.arguments.clone()),
                })
                .collect()
        });

        let done = choice.finish_reason.is_some();

        Ok(StreamChunk {
            content,
            tool_calls,
            done,
            usage: None,
        })
    } else {
        // No choices, treat as done
        Ok(StreamChunk {
            content: None,
            tool_calls: None,
            done: true,
            usage: None,
        })
    }
}

// ============================================================================
// ProviderInfo Implementation
// ============================================================================

/// Known GPT models with their specifications
fn get_gpt_models() -> Vec<ModelInfo> {
    vec![
        ModelInfo {
            id: "gpt-4o".to_string(),
            display_name: Some("GPT-4o".to_string()),
            description: Some("Most advanced multimodal model with vision".to_string()),
            context_window: Some(128_000),
            max_output_tokens: Some(16_384),
            supports_tools: true,
            supports_vision: true,
            input_price_per_mtok: Some(2.50),
            output_price_per_mtok: Some(10.0),
        },
        ModelInfo {
            id: "gpt-4o-mini".to_string(),
            display_name: Some("GPT-4o mini".to_string()),
            description: Some("Affordable and fast model for simple tasks".to_string()),
            context_window: Some(128_000),
            max_output_tokens: Some(16_384),
            supports_tools: true,
            supports_vision: true,
            input_price_per_mtok: Some(0.15),
            output_price_per_mtok: Some(0.60),
        },
        ModelInfo {
            id: "gpt-4-turbo".to_string(),
            display_name: Some("GPT-4 Turbo".to_string()),
            description: Some("Previous generation high-performance model".to_string()),
            context_window: Some(128_000),
            max_output_tokens: Some(4_096),
            supports_tools: true,
            supports_vision: true,
            input_price_per_mtok: Some(10.0),
            output_price_per_mtok: Some(30.0),
        },
        ModelInfo {
            id: "gpt-3.5-turbo".to_string(),
            display_name: Some("GPT-3.5 Turbo".to_string()),
            description: Some("Fast and affordable for simple tasks".to_string()),
            context_window: Some(16_385),
            max_output_tokens: Some(4_096),
            supports_tools: true,
            supports_vision: false,
            input_price_per_mtok: Some(0.50),
            output_price_per_mtok: Some(1.50),
        },
        ModelInfo {
            id: "o1".to_string(),
            display_name: Some("OpenAI o1".to_string()),
            description: Some("Reasoning model optimized for complex problems".to_string()),
            context_window: Some(200_000),
            max_output_tokens: Some(100_000),
            supports_tools: false,
            supports_vision: true,
            input_price_per_mtok: Some(15.0),
            output_price_per_mtok: Some(60.0),
        },
        ModelInfo {
            id: "o1-mini".to_string(),
            display_name: Some("OpenAI o1-mini".to_string()),
            description: Some("Faster reasoning model for STEM".to_string()),
            context_window: Some(128_000),
            max_output_tokens: Some(65_536),
            supports_tools: false,
            supports_vision: true,
            input_price_per_mtok: Some(3.0),
            output_price_per_mtok: Some(12.0),
        },
    ]
}

#[async_trait]
impl ProviderInfo for OpenAIClient {
    async fn status(&self) -> Result<ProviderStatus, LlmError> {
        let authenticated = self.auth.bearer_token().is_some();

        // Try to reach the API by listing models
        let reachable = if authenticated {
            let url = format!("{}{}", API_BASE_URL, MODELS_ENDPOINT);
            let bearer = self.auth.bearer_token().unwrap();

            match self
                .http
                .get(&url)
                .header("Authorization", format!("Bearer {}", bearer))
                .send()
                .await
            {
                Ok(resp) => resp.status().is_success(),
                Err(_) => false,
            }
        } else {
            false
        };

        let message = if !authenticated {
            Some("Not authenticated - set OPENAI_API_KEY or run 'owlen login openai'".to_string())
        } else if !reachable {
            Some("Cannot reach OpenAI API".to_string())
        } else {
            Some("Connected".to_string())
        };

        Ok(ProviderStatus {
            provider: "openai".to_string(),
            authenticated,
            account: None, // OpenAI doesn't expose account info via API
            model: self.model.clone(),
            endpoint: API_BASE_URL.to_string(),
            reachable,
            message,
        })
    }

    async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError> {
        // OpenAI doesn't have a public account info endpoint
        Ok(None)
    }

    async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError> {
        // OpenAI doesn't expose usage stats via the standard API
        Ok(None)
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
        // We can optionally fetch from API, but return known models for now
        Ok(get_gpt_models())
    }

    async fn model_info(&self, model_id: &str) -> Result<Option<ModelInfo>, LlmError> {
        let models = get_gpt_models();
        Ok(models.into_iter().find(|m| m.id == model_id))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use llm_core::ToolParameters;
    use serde_json::json;

    #[test]
    fn test_message_conversion() {
        let messages = vec![
            ChatMessage::system("You are helpful"),
            ChatMessage::user("Hello"),
            ChatMessage::assistant("Hi there!"),
        ];

        let openai_msgs = OpenAIClient::prepare_messages(&messages);

        assert_eq!(openai_msgs.len(), 3);
        assert_eq!(openai_msgs[0].role, "system");
        assert_eq!(openai_msgs[1].role, "user");
        assert_eq!(openai_msgs[2].role, "assistant");
    }

    #[test]
    fn test_tool_conversion() {
        let tools = vec![Tool::function(
            "read_file",
            "Read a file's contents",
            ToolParameters::object(
                json!({
                    "path": {
                        "type": "string",
                        "description": "File path"
                    }
                }),
                vec!["path".to_string()],
            ),
        )];

        let openai_tools = OpenAIClient::prepare_tools(Some(&tools)).unwrap();

        assert_eq!(openai_tools.len(), 1);
        assert_eq!(openai_tools[0].function.name, "read_file");
        assert_eq!(
            openai_tools[0].function.description,
            "Read a file's contents"
        );
    }
}
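The SSE handling above boils down to stripping the `data: ` prefix and parsing the JSON payload that follows. A minimal sketch against an illustrative event, trimmed to the fields the client actually reads:

```rust
use llm_openai::ChatCompletionChunk;

fn main() -> Result<(), serde_json::Error> {
    // One SSE event as the Chat Completions stream would deliver it (illustrative payload).
    let line = r#"data: {"id":"chatcmpl-1","object":"chat.completion.chunk","created":0,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hel"},"finish_reason":null}]}"#;

    // Mirror parse_sse_line: strip the "data: " prefix, then parse the JSON body.
    let payload = line.strip_prefix("data: ").expect("SSE data line");
    let chunk: ChatCompletionChunk = serde_json::from_str(payload)?;

    let delta = &chunk.choices[0].delta;
    assert_eq!(delta.content.as_deref(), Some("Hel"));
    assert!(chunk.choices[0].finish_reason.is_none());
    Ok(())
}
```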
12
crates/llm/openai/src/lib.rs
Normal file
@@ -0,0 +1,12 @@
//! OpenAI GPT API Client
//!
//! Implements the LlmProvider trait for OpenAI's GPT models.
//! Supports both API key authentication and OAuth device flow.

mod auth;
mod client;
mod types;

pub use auth::*;
pub use client::*;
pub use types::*;
285
crates/llm/openai/src/types.rs
Normal file
@@ -0,0 +1,285 @@
//! OpenAI API request/response types

use serde::{Deserialize, Serialize};
use serde_json::Value;

// ============================================================================
// Request Types
// ============================================================================

#[derive(Debug, Serialize)]
pub struct ChatCompletionRequest<'a> {
    pub model: &'a str,
    pub messages: Vec<OpenAIMessage>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<&'a [String]>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<OpenAITool>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<&'a str>,

    pub stream: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIMessage {
    pub role: String, // "system", "user", "assistant", "tool"

    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<OpenAIToolCall>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIToolCall {
    pub id: String,
    #[serde(rename = "type")]
    pub call_type: String,
    pub function: OpenAIFunctionCall,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIFunctionCall {
    pub name: String,
    pub arguments: String, // JSON string
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAITool {
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: OpenAIFunction,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIFunction {
    pub name: String,
    pub description: String,
    pub parameters: FunctionParameters,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionParameters {
    #[serde(rename = "type")]
    pub param_type: String,
    pub properties: Value,
    pub required: Vec<String>,
}

// ============================================================================
// Response Types
// ============================================================================

#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Option<UsageInfo>,
}

#[derive(Debug, Clone, Deserialize)]
pub struct Choice {
    pub index: u32,
    pub message: OpenAIMessage,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Clone, Deserialize)]
pub struct UsageInfo {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

// ============================================================================
// Streaming Response Types
// ============================================================================

#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionChunk {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChunkChoice>,
}

#[derive(Debug, Clone, Deserialize)]
pub struct ChunkChoice {
    pub index: u32,
    pub delta: Delta,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Clone, Deserialize)]
pub struct Delta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<DeltaToolCall>>,
}

#[derive(Debug, Clone, Deserialize)]
pub struct DeltaToolCall {
    pub index: usize,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none", rename = "type")]
    pub call_type: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<DeltaFunction>,
}

#[derive(Debug, Clone, Deserialize)]
pub struct DeltaFunction {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}

// ============================================================================
// Error Response Types
// ============================================================================

#[derive(Debug, Clone, Deserialize)]
pub struct ErrorResponse {
    pub error: ApiError,
}

#[derive(Debug, Clone, Deserialize)]
pub struct ApiError {
    pub message: String,
    #[serde(rename = "type")]
    pub error_type: String,
    pub code: Option<String>,
}

// ============================================================================
// Models List Response
// ============================================================================

#[derive(Debug, Clone, Deserialize)]
pub struct ModelsResponse {
    pub object: String,
    pub data: Vec<ModelData>,
}

#[derive(Debug, Clone, Deserialize)]
pub struct ModelData {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub owned_by: String,
}

// ============================================================================
// Conversions
// ============================================================================

impl From<&llm_core::Tool> for OpenAITool {
    fn from(tool: &llm_core::Tool) -> Self {
        Self {
            tool_type: "function".to_string(),
            function: OpenAIFunction {
                name: tool.function.name.clone(),
                description: tool.function.description.clone(),
                parameters: FunctionParameters {
                    param_type: tool.function.parameters.param_type.clone(),
                    properties: tool.function.parameters.properties.clone(),
                    required: tool.function.parameters.required.clone(),
                },
            },
        }
    }
}

impl From<&llm_core::ChatMessage> for OpenAIMessage {
    fn from(msg: &llm_core::ChatMessage) -> Self {
        use llm_core::Role;

        let role = match msg.role {
            Role::System => "system",
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::Tool => "tool",
        };

        // Handle tool result messages
        if msg.role == Role::Tool {
            return Self {
                role: "tool".to_string(),
                content: msg.content.clone(),
                tool_calls: None,
                tool_call_id: msg.tool_call_id.clone(),
                name: msg.name.clone(),
            };
        }

        // Handle assistant messages with tool calls
        if msg.role == Role::Assistant && msg.tool_calls.is_some() {
            let tool_calls = msg.tool_calls.as_ref().map(|calls| {
                calls
                    .iter()
                    .map(|call| OpenAIToolCall {
                        id: call.id.clone(),
                        call_type: "function".to_string(),
                        function: OpenAIFunctionCall {
                            name: call.function.name.clone(),
                            arguments: serde_json::to_string(&call.function.arguments)
                                .unwrap_or_else(|_| "{}".to_string()),
                        },
                    })
                    .collect()
            });

            return Self {
                role: "assistant".to_string(),
                content: msg.content.clone(),
                tool_calls,
                tool_call_id: None,
                name: None,
            };
        }

        // Simple text message
        Self {
            role: role.to_string(),
            content: msg.content.clone(),
            tool_calls: None,
            tool_call_id: None,
            name: None,
        }
    }
}
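A short sketch of the message conversion above, assuming `ChatMessage::user` stores its text as `Some(content)`, as the other conversions in this change imply:

```rust
use llm_core::ChatMessage;
use llm_openai::OpenAIMessage;

fn main() {
    // Core message as the rest of the workspace produces it.
    let core_msg = ChatMessage::user("List the open pull requests.");

    // The From impl maps the role and copies content and tool calls over.
    let openai_msg = OpenAIMessage::from(&core_msg);
    assert_eq!(openai_msg.role, "user");
    assert_eq!(openai_msg.content.as_deref(), Some("List the open pull requests."));
    assert!(openai_msg.tool_calls.is_none());
}
```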
@@ -1,12 +0,0 @@
[package]
name = "owlen-mcp-client"
version = "0.1.0"
edition.workspace = true
description = "Dedicated MCP client library for Owlen, exposing remote MCP server communication"
license = "AGPL-3.0"

[dependencies]
owlen-core = { path = "../../owlen-core" }

[features]
default = []
@@ -1,17 +0,0 @@
//! Owlen MCP client library.
//!
//! This crate provides a thin façade over the remote MCP client implementation
//! inside `owlen-core`. It re‑exports the most useful types so downstream
//! crates can depend only on `owlen-mcp-client` without pulling in the entire
//! core crate internals.

pub use owlen_core::config::{McpConfigScope, ScopedMcpServer};
pub use owlen_core::mcp::remote_client::RemoteMcpClient;
pub use owlen_core::mcp::{McpClient, McpToolCall, McpToolDescriptor, McpToolResponse};

// Re‑export the core Provider trait so that the MCP client can also be used as an LLM provider.
pub use owlen_core::Provider as McpProvider;

// Note: The `RemoteMcpClient` type provides its own `new` constructor in the core
// crate. Users can call `RemoteMcpClient::new()` directly. No additional wrapper
// is needed here.
@@ -1,22 +0,0 @@
[package]
name = "owlen-mcp-code-server"
version = "0.1.0"
edition.workspace = true
description = "MCP server exposing safe code execution tools for Owlen"
license = "AGPL-3.0"

[dependencies]
owlen-core = { path = "../../owlen-core" }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
anyhow = { workspace = true }
async-trait = { workspace = true }
bollard = "0.17"
tempfile = { workspace = true }
uuid = { workspace = true }
futures = { workspace = true }

[lib]
name = "owlen_mcp_code_server"
path = "src/lib.rs"
@@ -1,186 +0,0 @@
|
||||
//! MCP server exposing code execution tools with Docker sandboxing.
|
||||
//!
|
||||
//! This server provides:
|
||||
//! - compile_project: Build projects (Rust, Node.js, Python)
|
||||
//! - run_tests: Execute test suites
|
||||
//! - format_code: Run code formatters
|
||||
//! - lint_code: Run linters
|
||||
|
||||
pub mod sandbox;
|
||||
pub mod tools;
|
||||
|
||||
use owlen_core::mcp::protocol::{
|
||||
ErrorCode, InitializeParams, InitializeResult, PROTOCOL_VERSION, RequestId, RpcError,
|
||||
RpcErrorResponse, RpcRequest, RpcResponse, ServerCapabilities, ServerInfo, methods,
|
||||
};
|
||||
use owlen_core::tools::{Tool, ToolResult};
|
||||
use serde_json::{Value, json};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt};
|
||||
|
||||
use tools::{CompileProjectTool, FormatCodeTool, LintCodeTool, RunTestsTool};
|
||||
|
||||
/// Tool registry for the code server
|
||||
#[allow(dead_code)]
|
||||
struct ToolRegistry {
|
||||
tools: HashMap<String, Box<dyn Tool + Send + Sync>>,
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl ToolRegistry {
|
||||
fn new() -> Self {
|
||||
let mut tools: HashMap<String, Box<dyn Tool + Send + Sync>> = HashMap::new();
|
||||
tools.insert(
|
||||
"compile_project".to_string(),
|
||||
Box::new(CompileProjectTool::new()),
|
||||
);
|
||||
tools.insert("run_tests".to_string(), Box::new(RunTestsTool::new()));
|
||||
tools.insert("format_code".to_string(), Box::new(FormatCodeTool::new()));
|
||||
tools.insert("lint_code".to_string(), Box::new(LintCodeTool::new()));
|
||||
Self { tools }
|
||||
}
|
||||
|
||||
fn list_tools(&self) -> Vec<owlen_core::mcp::McpToolDescriptor> {
|
||||
self.tools
|
||||
.values()
|
||||
.map(|tool| owlen_core::mcp::McpToolDescriptor {
|
||||
name: tool.name().to_string(),
|
||||
description: tool.description().to_string(),
|
||||
input_schema: tool.schema(),
|
||||
requires_network: tool.requires_network(),
|
||||
requires_filesystem: tool.requires_filesystem(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn execute(&self, name: &str, args: Value) -> Result<ToolResult, String> {
|
||||
self.tools
|
||||
.get(name)
|
||||
.ok_or_else(|| format!("Tool not found: {}", name))?
|
||||
.execute(args)
|
||||
.await
|
||||
.map_err(|e| e.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let mut stdin = io::BufReader::new(io::stdin());
|
||||
let mut stdout = io::stdout();
|
||||
|
||||
let registry = Arc::new(ToolRegistry::new());
|
||||
|
||||
loop {
|
||||
let mut line = String::new();
|
||||
match stdin.read_line(&mut line).await {
|
||||
Ok(0) => break, // EOF
|
||||
Ok(_) => {
|
||||
let req: RpcRequest = match serde_json::from_str(&line) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
let err = RpcErrorResponse::new(
|
||||
RequestId::Number(0),
|
||||
RpcError::parse_error(format!("Parse error: {}", e)),
|
||||
);
|
||||
let s = serde_json::to_string(&err)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let resp = handle_request(req.clone(), registry.clone()).await;
|
||||
match resp {
|
||||
Ok(r) => {
|
||||
let s = serde_json::to_string(&r)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
}
|
||||
Err(e) => {
|
||||
let err = RpcErrorResponse::new(req.id.clone(), e);
|
||||
let s = serde_json::to_string(&err)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error reading stdin: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
async fn handle_request(
|
||||
req: RpcRequest,
|
||||
registry: Arc<ToolRegistry>,
|
||||
) -> Result<RpcResponse, RpcError> {
|
||||
match req.method.as_str() {
|
||||
methods::INITIALIZE => {
|
||||
let params: InitializeParams =
|
||||
serde_json::from_value(req.params.unwrap_or_else(|| json!({})))
|
||||
.map_err(|e| RpcError::invalid_params(format!("Invalid init params: {}", e)))?;
|
||||
if !params.protocol_version.eq(PROTOCOL_VERSION) {
|
||||
return Err(RpcError::new(
|
||||
ErrorCode::INVALID_REQUEST,
|
||||
format!(
|
||||
"Incompatible protocol version. Client: {}, Server: {}",
|
||||
params.protocol_version, PROTOCOL_VERSION
|
||||
),
|
||||
));
|
||||
}
|
||||
let result = InitializeResult {
|
||||
protocol_version: PROTOCOL_VERSION.to_string(),
|
||||
server_info: ServerInfo {
|
||||
name: "owlen-mcp-code-server".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
},
|
||||
capabilities: ServerCapabilities {
|
||||
supports_tools: Some(true),
|
||||
supports_resources: Some(false),
|
||||
supports_streaming: Some(false),
|
||||
},
|
||||
};
|
||||
let payload = serde_json::to_value(result).map_err(|e| {
|
||||
RpcError::internal_error(format!("Failed to serialize initialize result: {}", e))
|
||||
})?;
|
||||
Ok(RpcResponse::new(req.id, payload))
|
||||
}
|
||||
methods::TOOLS_LIST => {
|
||||
let tools = registry.list_tools();
|
||||
Ok(RpcResponse::new(req.id, json!(tools)))
|
||||
}
|
||||
methods::TOOLS_CALL => {
|
||||
let call = serde_json::from_value::<owlen_core::mcp::McpToolCall>(
|
||||
req.params.unwrap_or_else(|| json!({})),
|
||||
)
|
||||
.map_err(|e| RpcError::invalid_params(format!("Invalid tool call: {}", e)))?;
|
||||
|
||||
let result: ToolResult = registry
|
||||
.execute(&call.name, call.arguments)
|
||||
.await
|
||||
.map_err(|e| RpcError::internal_error(format!("Tool execution failed: {}", e)))?;
|
||||
|
||||
let resp = owlen_core::mcp::McpToolResponse {
|
||||
name: call.name,
|
||||
success: result.success,
|
||||
output: result.output,
|
||||
metadata: result.metadata,
|
||||
duration_ms: result.duration.as_millis() as u128,
|
||||
};
|
||||
let payload = serde_json::to_value(resp).map_err(|e| {
|
||||
RpcError::internal_error(format!("Failed to serialize tool response: {}", e))
|
||||
})?;
|
||||
Ok(RpcResponse::new(req.id, payload))
|
||||
}
|
||||
_ => Err(RpcError::method_not_found(&req.method)),
|
||||
}
|
||||
}
|
||||
@@ -1,250 +0,0 @@
|
||||
//! Docker-based sandboxing for secure code execution
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use bollard::Docker;
|
||||
use bollard::container::{
|
||||
Config, CreateContainerOptions, RemoveContainerOptions, StartContainerOptions,
|
||||
WaitContainerOptions,
|
||||
};
|
||||
use bollard::models::{HostConfig, Mount, MountTypeEnum};
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
||||
/// Result of executing code in a sandbox
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ExecutionResult {
|
||||
pub stdout: String,
|
||||
pub stderr: String,
|
||||
pub exit_code: i64,
|
||||
pub timed_out: bool,
|
||||
}
|
||||
|
||||
/// Docker-based sandbox executor
|
||||
pub struct Sandbox {
|
||||
docker: Docker,
|
||||
memory_limit: i64,
|
||||
cpu_quota: i64,
|
||||
timeout_secs: u64,
|
||||
}
|
||||
|
||||
impl Sandbox {
|
||||
/// Create a new sandbox with default resource limits
|
||||
pub fn new() -> Result<Self> {
|
||||
let docker =
|
||||
Docker::connect_with_local_defaults().context("Failed to connect to Docker daemon")?;
|
||||
|
||||
Ok(Self {
|
||||
docker,
|
||||
memory_limit: 512 * 1024 * 1024, // 512MB
|
||||
cpu_quota: 50000, // 50% of one core
|
||||
timeout_secs: 30,
|
||||
})
|
||||
}
|
||||
|
||||
/// Execute a command in a sandboxed container
|
||||
pub async fn execute(
|
||||
&self,
|
||||
image: &str,
|
||||
cmd: &[&str],
|
||||
workspace: Option<&Path>,
|
||||
env: HashMap<String, String>,
|
||||
) -> Result<ExecutionResult> {
|
||||
let container_name = format!("owlen-sandbox-{}", uuid::Uuid::new_v4());
|
||||
|
||||
// Prepare volume mount if workspace provided
|
||||
let mounts = if let Some(ws) = workspace {
|
||||
vec![Mount {
|
||||
target: Some("/workspace".to_string()),
|
||||
source: Some(ws.to_string_lossy().to_string()),
|
||||
typ: Some(MountTypeEnum::BIND),
|
||||
read_only: Some(false),
|
||||
..Default::default()
|
||||
}]
|
||||
} else {
|
||||
vec![]
|
||||
};
|
||||
|
||||
// Create container config
|
||||
let host_config = HostConfig {
|
||||
memory: Some(self.memory_limit),
|
||||
cpu_quota: Some(self.cpu_quota),
|
||||
network_mode: Some("none".to_string()), // No network access
|
||||
mounts: Some(mounts),
|
||||
auto_remove: Some(true),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let config = Config {
|
||||
image: Some(image.to_string()),
|
||||
cmd: Some(cmd.iter().map(|s| s.to_string()).collect()),
|
||||
working_dir: Some("/workspace".to_string()),
|
||||
env: Some(env.iter().map(|(k, v)| format!("{}={}", k, v)).collect()),
|
||||
host_config: Some(host_config),
|
||||
attach_stdout: Some(true),
|
||||
attach_stderr: Some(true),
|
||||
tty: Some(false),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Create container
|
||||
let container = self
|
||||
.docker
|
||||
.create_container(
|
||||
Some(CreateContainerOptions {
|
||||
name: container_name.clone(),
|
||||
..Default::default()
|
||||
}),
|
||||
config,
|
||||
)
|
||||
.await
|
||||
.context("Failed to create container")?;
|
||||
|
||||
// Start container
|
||||
self.docker
|
||||
.start_container(&container.id, None::<StartContainerOptions<String>>)
|
||||
.await
|
||||
.context("Failed to start container")?;
|
||||
|
||||
// Wait for container with timeout
|
||||
let wait_result =
|
||||
tokio::time::timeout(std::time::Duration::from_secs(self.timeout_secs), async {
|
||||
let mut wait_stream = self
|
||||
.docker
|
||||
.wait_container(&container.id, None::<WaitContainerOptions<String>>);
|
||||
|
||||
use futures::StreamExt;
|
||||
if let Some(result) = wait_stream.next().await {
|
||||
result
|
||||
} else {
|
||||
Err(bollard::errors::Error::IOError {
|
||||
err: std::io::Error::other("Container wait stream ended unexpectedly"),
|
||||
})
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
let (exit_code, timed_out) = match wait_result {
|
||||
Ok(Ok(result)) => (result.status_code, false),
|
||||
Ok(Err(e)) => {
|
||||
eprintln!("Container wait error: {}", e);
|
||||
(1, false)
|
||||
}
|
||||
Err(_) => {
|
||||
// Timeout - kill the container
|
||||
let _ = self
|
||||
.docker
|
||||
.kill_container(
|
||||
&container.id,
|
||||
None::<bollard::container::KillContainerOptions<String>>,
|
||||
)
|
||||
.await;
|
||||
(124, true)
|
||||
}
|
||||
};
|
||||
|
||||
// Get logs
|
||||
let logs = self.docker.logs(
|
||||
&container.id,
|
||||
Some(bollard::container::LogsOptions::<String> {
|
||||
stdout: true,
|
||||
stderr: true,
|
||||
..Default::default()
|
||||
}),
|
||||
);
|
||||
|
||||
use futures::StreamExt;
|
||||
let mut stdout = String::new();
|
||||
let mut stderr = String::new();
|
||||
|
||||
let log_result = tokio::time::timeout(std::time::Duration::from_secs(5), async {
|
||||
let mut logs = logs;
|
||||
while let Some(log) = logs.next().await {
|
||||
match log {
|
||||
Ok(bollard::container::LogOutput::StdOut { message }) => {
|
||||
stdout.push_str(&String::from_utf8_lossy(&message));
|
||||
}
|
||||
Ok(bollard::container::LogOutput::StdErr { message }) => {
|
||||
stderr.push_str(&String::from_utf8_lossy(&message));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
if log_result.is_err() {
|
||||
eprintln!("Timeout reading container logs");
|
||||
}
|
||||
|
||||
// Remove container (auto_remove should handle this, but be explicit)
|
||||
let _ = self
|
||||
.docker
|
||||
.remove_container(
|
||||
&container.id,
|
||||
Some(RemoveContainerOptions {
|
||||
force: true,
|
||||
..Default::default()
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(ExecutionResult {
|
||||
stdout,
|
||||
stderr,
|
||||
exit_code,
|
||||
timed_out,
|
||||
})
|
||||
}
|
||||
|
||||
/// Execute in a Rust environment
|
||||
pub async fn execute_rust(&self, workspace: &Path, cmd: &[&str]) -> Result<ExecutionResult> {
|
||||
self.execute("rust:1.75-slim", cmd, Some(workspace), HashMap::new())
|
||||
.await
|
||||
}
|
||||
|
||||
/// Execute in a Python environment
|
||||
pub async fn execute_python(&self, workspace: &Path, cmd: &[&str]) -> Result<ExecutionResult> {
|
||||
self.execute("python:3.11-slim", cmd, Some(workspace), HashMap::new())
|
||||
.await
|
||||
}
|
||||
|
||||
/// Execute in a Node.js environment
|
||||
pub async fn execute_node(&self, workspace: &Path, cmd: &[&str]) -> Result<ExecutionResult> {
|
||||
self.execute("node:20-slim", cmd, Some(workspace), HashMap::new())
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Sandbox {
|
||||
fn default() -> Self {
|
||||
Self::new().expect("Failed to create default sandbox")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[tokio::test]
|
||||
#[ignore] // Requires Docker daemon
|
||||
async fn test_sandbox_rust_compile() {
|
||||
let sandbox = Sandbox::new().unwrap();
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
|
||||
// Create a simple Rust project
|
||||
std::fs::write(
|
||||
temp_dir.path().join("main.rs"),
|
||||
"fn main() { println!(\"Hello from sandbox!\"); }",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let result = sandbox
|
||||
.execute_rust(temp_dir.path(), &["rustc", "main.rs"])
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.exit_code, 0);
|
||||
assert!(!result.timed_out);
|
||||
}
|
||||
}
|
||||
@@ -1,417 +0,0 @@
|
||||
//! Code execution tools using Docker sandboxing
|
||||
|
||||
use crate::sandbox::Sandbox;
|
||||
use async_trait::async_trait;
|
||||
use owlen_core::Result;
|
||||
use owlen_core::tools::{Tool, ToolResult};
|
||||
use serde_json::{Value, json};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Tool for compiling projects (Rust, Node.js, Python)
|
||||
pub struct CompileProjectTool {
|
||||
sandbox: Sandbox,
|
||||
}
|
||||
|
||||
impl Default for CompileProjectTool {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl CompileProjectTool {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sandbox: Sandbox::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for CompileProjectTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"compile_project"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Compile a project (Rust, Node.js, Python). Detects project type automatically."
|
||||
}
|
||||
|
||||
fn schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"project_path": {
|
||||
"type": "string",
|
||||
"description": "Path to the project root"
|
||||
},
|
||||
"project_type": {
|
||||
"type": "string",
|
||||
"enum": ["rust", "node", "python"],
|
||||
"description": "Project type (auto-detected if not specified)"
|
||||
}
|
||||
},
|
||||
"required": ["project_path"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: Value) -> Result<ToolResult> {
|
||||
let project_path = args
|
||||
.get("project_path")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| owlen_core::Error::InvalidInput("Missing project_path".into()))?;
|
||||
|
||||
let path = PathBuf::from(project_path);
|
||||
if !path.exists() {
|
||||
return Ok(ToolResult::error("Project path does not exist"));
|
||||
}
|
||||
|
||||
// Detect project type
|
||||
let project_type = if let Some(pt) = args.get("project_type").and_then(|v| v.as_str()) {
|
||||
pt.to_string()
|
||||
} else if path.join("Cargo.toml").exists() {
|
||||
"rust".to_string()
|
||||
} else if path.join("package.json").exists() {
|
||||
"node".to_string()
|
||||
} else if path.join("setup.py").exists() || path.join("pyproject.toml").exists() {
|
||||
"python".to_string()
|
||||
} else {
|
||||
return Ok(ToolResult::error("Could not detect project type"));
|
||||
};
|
||||
|
||||
// Execute compilation
|
||||
let result = match project_type.as_str() {
|
||||
"rust" => self.sandbox.execute_rust(&path, &["cargo", "build"]).await,
|
||||
"node" => {
|
||||
self.sandbox
|
||||
.execute_node(&path, &["npm", "run", "build"])
|
||||
.await
|
||||
}
|
||||
"python" => {
|
||||
// Python typically doesn't need compilation, but we can check syntax
|
||||
self.sandbox
|
||||
.execute_python(&path, &["python", "-m", "compileall", "."])
|
||||
.await
|
||||
}
|
||||
_ => return Ok(ToolResult::error("Unsupported project type")),
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(exec_result) => {
|
||||
if exec_result.timed_out {
|
||||
Ok(ToolResult::error("Compilation timed out"))
|
||||
} else if exec_result.exit_code == 0 {
|
||||
Ok(ToolResult::success(json!({
|
||||
"success": true,
|
||||
"stdout": exec_result.stdout,
|
||||
"stderr": exec_result.stderr,
|
||||
"project_type": project_type
|
||||
})))
|
||||
} else {
|
||||
Ok(ToolResult::success(json!({
|
||||
"success": false,
|
||||
"exit_code": exec_result.exit_code,
|
||||
"stdout": exec_result.stdout,
|
||||
"stderr": exec_result.stderr,
|
||||
"project_type": project_type
|
||||
})))
|
||||
}
|
||||
}
|
||||
Err(e) => Ok(ToolResult::error(&format!("Compilation failed: {}", e))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tool for running test suites
|
||||
pub struct RunTestsTool {
|
||||
sandbox: Sandbox,
|
||||
}
|
||||
|
||||
impl Default for RunTestsTool {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl RunTestsTool {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sandbox: Sandbox::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for RunTestsTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"run_tests"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Run tests for a project (Rust, Node.js, Python)"
|
||||
}
|
||||
|
||||
fn schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"project_path": {
|
||||
"type": "string",
|
||||
"description": "Path to the project root"
|
||||
},
|
||||
"test_filter": {
|
||||
"type": "string",
|
||||
"description": "Optional test filter/pattern"
|
||||
}
|
||||
},
|
||||
"required": ["project_path"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: Value) -> Result<ToolResult> {
|
||||
let project_path = args
|
||||
.get("project_path")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| owlen_core::Error::InvalidInput("Missing project_path".into()))?;
|
||||
|
||||
let path = PathBuf::from(project_path);
|
||||
if !path.exists() {
|
||||
return Ok(ToolResult::error("Project path does not exist"));
|
||||
}
|
||||
|
||||
let test_filter = args.get("test_filter").and_then(|v| v.as_str());
|
||||
|
||||
// Detect project type and run tests
|
||||
let result = if path.join("Cargo.toml").exists() {
|
||||
let cmd = if let Some(filter) = test_filter {
|
||||
vec!["cargo", "test", filter]
|
||||
} else {
|
||||
vec!["cargo", "test"]
|
||||
};
|
||||
self.sandbox.execute_rust(&path, &cmd).await
|
||||
} else if path.join("package.json").exists() {
|
||||
self.sandbox.execute_node(&path, &["npm", "test"]).await
|
||||
} else if path.join("pytest.ini").exists()
|
||||
|| path.join("setup.py").exists()
|
||||
|| path.join("pyproject.toml").exists()
|
||||
{
|
||||
let cmd = if let Some(filter) = test_filter {
|
||||
vec!["pytest", "-k", filter]
|
||||
} else {
|
||||
vec!["pytest"]
|
||||
};
|
||||
self.sandbox.execute_python(&path, &cmd).await
|
||||
} else {
|
||||
return Ok(ToolResult::error("Could not detect test framework"));
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(exec_result) => Ok(ToolResult::success(json!({
|
||||
"success": exec_result.exit_code == 0 && !exec_result.timed_out,
|
||||
"exit_code": exec_result.exit_code,
|
||||
"stdout": exec_result.stdout,
|
||||
"stderr": exec_result.stderr,
|
||||
"timed_out": exec_result.timed_out
|
||||
}))),
|
||||
Err(e) => Ok(ToolResult::error(&format!("Tests failed to run: {}", e))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tool for formatting code
|
||||
pub struct FormatCodeTool {
|
||||
sandbox: Sandbox,
|
||||
}
|
||||
|
||||
impl Default for FormatCodeTool {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl FormatCodeTool {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sandbox: Sandbox::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for FormatCodeTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"format_code"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Format code using project-appropriate formatter (rustfmt, prettier, black)"
|
||||
}
|
||||
|
||||
fn schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"project_path": {
|
||||
"type": "string",
|
||||
"description": "Path to the project root"
|
||||
},
|
||||
"check_only": {
|
||||
"type": "boolean",
|
||||
"description": "Only check formatting without modifying files",
|
||||
"default": false
|
||||
}
|
||||
},
|
||||
"required": ["project_path"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: Value) -> Result<ToolResult> {
|
||||
let project_path = args
|
||||
.get("project_path")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| owlen_core::Error::InvalidInput("Missing project_path".into()))?;
|
||||
|
||||
let path = PathBuf::from(project_path);
|
||||
if !path.exists() {
|
||||
return Ok(ToolResult::error("Project path does not exist"));
|
||||
}
|
||||
|
||||
let check_only = args
|
||||
.get("check_only")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
// Detect project type and run formatter
|
||||
let result = if path.join("Cargo.toml").exists() {
|
||||
let cmd = if check_only {
|
||||
vec!["cargo", "fmt", "--", "--check"]
|
||||
} else {
|
||||
vec!["cargo", "fmt"]
|
||||
};
|
||||
self.sandbox.execute_rust(&path, &cmd).await
|
||||
} else if path.join("package.json").exists() {
|
||||
let cmd = if check_only {
|
||||
vec!["npx", "prettier", "--check", "."]
|
||||
} else {
|
||||
vec!["npx", "prettier", "--write", "."]
|
||||
};
|
||||
self.sandbox.execute_node(&path, &cmd).await
|
||||
} else if path.join("setup.py").exists() || path.join("pyproject.toml").exists() {
|
||||
let cmd = if check_only {
|
||||
vec!["black", "--check", "."]
|
||||
} else {
|
||||
vec!["black", "."]
|
||||
};
|
||||
self.sandbox.execute_python(&path, &cmd).await
|
||||
} else {
|
||||
return Ok(ToolResult::error("Could not detect project type"));
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(exec_result) => Ok(ToolResult::success(json!({
|
||||
"success": exec_result.exit_code == 0,
|
||||
"formatted": !check_only && exec_result.exit_code == 0,
|
||||
"stdout": exec_result.stdout,
|
||||
"stderr": exec_result.stderr
|
||||
}))),
|
||||
Err(e) => Ok(ToolResult::error(&format!("Formatting failed: {}", e))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Tool for linting code
|
||||
pub struct LintCodeTool {
|
||||
sandbox: Sandbox,
|
||||
}
|
||||
|
||||
impl Default for LintCodeTool {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl LintCodeTool {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
sandbox: Sandbox::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for LintCodeTool {
|
||||
fn name(&self) -> &'static str {
|
||||
"lint_code"
|
||||
}
|
||||
|
||||
fn description(&self) -> &'static str {
|
||||
"Lint code using project-appropriate linter (clippy, eslint, pylint)"
|
||||
}
|
||||
|
||||
fn schema(&self) -> Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"project_path": {
|
||||
"type": "string",
|
||||
"description": "Path to the project root"
|
||||
},
|
||||
"fix": {
|
||||
"type": "boolean",
|
||||
"description": "Automatically fix issues if possible",
|
||||
"default": false
|
||||
}
|
||||
},
|
||||
"required": ["project_path"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: Value) -> Result<ToolResult> {
|
||||
let project_path = args
|
||||
.get("project_path")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| owlen_core::Error::InvalidInput("Missing project_path".into()))?;
|
||||
|
||||
let path = PathBuf::from(project_path);
|
||||
if !path.exists() {
|
||||
return Ok(ToolResult::error("Project path does not exist"));
|
||||
}
|
||||
|
||||
let fix = args.get("fix").and_then(|v| v.as_bool()).unwrap_or(false);
|
||||
|
||||
// Detect project type and run linter
|
||||
let result = if path.join("Cargo.toml").exists() {
|
||||
let cmd = if fix {
|
||||
vec!["cargo", "clippy", "--fix", "--allow-dirty"]
|
||||
} else {
|
||||
vec!["cargo", "clippy"]
|
||||
};
|
||||
self.sandbox.execute_rust(&path, &cmd).await
|
||||
} else if path.join("package.json").exists() {
|
||||
let cmd = if fix {
|
||||
vec!["npx", "eslint", ".", "--fix"]
|
||||
} else {
|
||||
vec!["npx", "eslint", "."]
|
||||
};
|
||||
self.sandbox.execute_node(&path, &cmd).await
|
||||
} else if path.join("setup.py").exists() || path.join("pyproject.toml").exists() {
|
||||
// pylint doesn't have auto-fix
|
||||
self.sandbox.execute_python(&path, &["pylint", "."]).await
|
||||
} else {
|
||||
return Ok(ToolResult::error("Could not detect project type"));
|
||||
};
|
||||
|
||||
match result {
|
||||
Ok(exec_result) => {
|
||||
let issues_found = exec_result.exit_code != 0;
|
||||
Ok(ToolResult::success(json!({
|
||||
"success": true,
|
||||
"issues_found": issues_found,
|
||||
"exit_code": exec_result.exit_code,
|
||||
"stdout": exec_result.stdout,
|
||||
"stderr": exec_result.stderr
|
||||
})))
|
||||
}
|
||||
Err(e) => Ok(ToolResult::error(&format!("Linting failed: {}", e))),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
[package]
|
||||
name = "owlen-mcp-llm-server"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
owlen-core = { path = "../../owlen-core" }
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
tokio-stream = { workspace = true }
|
||||
|
||||
[[bin]]
|
||||
name = "owlen-mcp-llm-server"
|
||||
path = "src/main.rs"
|
||||
@@ -1,597 +0,0 @@
|
||||
#![allow(
|
||||
unused_imports,
|
||||
unused_variables,
|
||||
dead_code,
|
||||
clippy::unnecessary_cast,
|
||||
clippy::manual_flatten,
|
||||
clippy::empty_line_after_outer_attr
|
||||
)]
|
||||
|
||||
use owlen_core::Provider;
|
||||
use owlen_core::ProviderConfig;
|
||||
use owlen_core::config::{Config as OwlenConfig, ensure_provider_config};
|
||||
use owlen_core::mcp::protocol::{
|
||||
ErrorCode, InitializeParams, InitializeResult, PROTOCOL_VERSION, RequestId, RpcError,
|
||||
RpcErrorResponse, RpcNotification, RpcRequest, RpcResponse, ServerCapabilities, ServerInfo,
|
||||
methods,
|
||||
};
|
||||
use owlen_core::mcp::{McpToolCall, McpToolDescriptor, McpToolResponse};
|
||||
use owlen_core::providers::OllamaProvider;
|
||||
use owlen_core::types::{ChatParameters, ChatRequest, Message};
|
||||
use serde::Deserialize;
|
||||
use serde_json::{Value, json};
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::sync::Arc;
|
||||
use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt};
|
||||
use tokio_stream::StreamExt;
|
||||
|
||||
// Suppress warnings are handled by the crate-level attribute at the top.
|
||||
|
||||
/// Arguments for the generate_text tool
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GenerateTextArgs {
|
||||
messages: Vec<Message>,
|
||||
temperature: Option<f32>,
|
||||
max_tokens: Option<u32>,
|
||||
model: String,
|
||||
stream: bool,
|
||||
}
|
||||
|
||||
/// Simple tool descriptor for generate_text
|
||||
fn generate_text_descriptor() -> McpToolDescriptor {
|
||||
McpToolDescriptor {
|
||||
name: "generate_text".to_string(),
|
||||
description: "Generate text using Ollama LLM. Each message must have 'role' (user/assistant/system) and 'content' (string) fields.".to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"messages": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"role": {
|
||||
"type": "string",
|
||||
"enum": ["user", "assistant", "system"],
|
||||
"description": "The role of the message sender"
|
||||
},
|
||||
"content": {
|
||||
"type": "string",
|
||||
"description": "The message content"
|
||||
}
|
||||
},
|
||||
"required": ["role", "content"]
|
||||
},
|
||||
"description": "Array of message objects with role and content"
|
||||
},
|
||||
"temperature": {"type": ["number", "null"], "description": "Sampling temperature (0.0-2.0)"},
|
||||
"max_tokens": {"type": ["integer", "null"], "description": "Maximum tokens to generate"},
|
||||
"model": {"type": "string", "description": "Model name (e.g., llama3.2:latest)"},
|
||||
"stream": {"type": "boolean", "description": "Whether to stream the response"}
|
||||
},
|
||||
"required": ["messages", "model", "stream"]
|
||||
}),
|
||||
requires_network: true,
|
||||
requires_filesystem: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
/// Tool descriptor for resources/get (read file)
|
||||
fn resources_get_descriptor() -> McpToolDescriptor {
|
||||
McpToolDescriptor {
|
||||
name: "resources/get".to_string(),
|
||||
description: "Read and return the TEXT CONTENTS of a single FILE. Use this to read the contents of code files, config files, or text documents. Do NOT use for directories.".to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Path to the FILE (not directory) to read"}
|
||||
},
|
||||
"required": ["path"]
|
||||
}),
|
||||
requires_network: false,
|
||||
requires_filesystem: vec!["read".to_string()],
|
||||
}
|
||||
}
|
||||
|
||||
/// Tool descriptor for resources/list (list directory)
|
||||
fn resources_list_descriptor() -> McpToolDescriptor {
|
||||
McpToolDescriptor {
|
||||
name: "resources/list".to_string(),
|
||||
description: "List the NAMES of all files and directories in a directory. Use this to see what files exist in a folder, or to list directory contents. Returns an array of file/directory names.".to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"path": {"type": "string", "description": "Path to the DIRECTORY to list (use '.' for current directory)"}
|
||||
}
|
||||
}),
|
||||
requires_network: false,
|
||||
requires_filesystem: vec!["read".to_string()],
|
||||
}
|
||||
}
|
||||
|
||||
fn provider_from_config() -> Result<Arc<dyn Provider>, RpcError> {
|
||||
let mut config = OwlenConfig::load(None).unwrap_or_default();
|
||||
let requested_name =
|
||||
env::var("OWLEN_PROVIDER").unwrap_or_else(|_| config.general.default_provider.clone());
|
||||
let provider_key = canonical_provider_name(&requested_name);
|
||||
if config.provider(&provider_key).is_none() {
|
||||
ensure_provider_config(&mut config, &provider_key);
|
||||
}
|
||||
let provider_cfg: ProviderConfig =
|
||||
config.provider(&provider_key).cloned().ok_or_else(|| {
|
||||
RpcError::internal_error(format!(
|
||||
"Provider '{provider_key}' not found in configuration"
|
||||
))
|
||||
})?;
|
||||
|
||||
match provider_cfg.provider_type.as_str() {
|
||||
"ollama" | "ollama_cloud" => {
|
||||
let provider = OllamaProvider::from_config(&provider_cfg, Some(&config.general))
|
||||
.map_err(|e| {
|
||||
RpcError::internal_error(format!(
|
||||
"Failed to init Ollama provider from config: {e}"
|
||||
))
|
||||
})?;
|
||||
Ok(Arc::new(provider) as Arc<dyn Provider>)
|
||||
}
|
||||
other => Err(RpcError::internal_error(format!(
|
||||
"Unsupported provider type '{other}' for MCP LLM server"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
fn create_provider() -> Result<Arc<dyn Provider>, RpcError> {
|
||||
if let Ok(url) = env::var("OLLAMA_URL") {
|
||||
let provider = OllamaProvider::new(&url).map_err(|e| {
|
||||
RpcError::internal_error(format!("Failed to init Ollama provider: {e}"))
|
||||
})?;
|
||||
return Ok(Arc::new(provider) as Arc<dyn Provider>);
|
||||
}
|
||||
|
||||
provider_from_config()
|
||||
}
|
||||
|
||||
fn canonical_provider_name(name: &str) -> String {
|
||||
let normalized = name.trim().to_ascii_lowercase().replace('-', "_");
|
||||
match normalized.as_str() {
|
||||
"" => "ollama_local".to_string(),
|
||||
"ollama" | "ollama_local" => "ollama_local".to_string(),
|
||||
"ollama_cloud" => "ollama_cloud".to_string(),
|
||||
other => other.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_generate_text(args: GenerateTextArgs) -> Result<String, RpcError> {
|
||||
let provider = create_provider()?;
|
||||
|
||||
let parameters = ChatParameters {
|
||||
temperature: args.temperature,
|
||||
max_tokens: args.max_tokens.map(|v| v as u32),
|
||||
stream: args.stream,
|
||||
extra: HashMap::new(),
|
||||
};
|
||||
|
||||
let request = ChatRequest {
|
||||
model: args.model,
|
||||
messages: args.messages,
|
||||
parameters,
|
||||
tools: None,
|
||||
};
|
||||
|
||||
// Use streaming API and collect output
|
||||
let mut stream = provider
|
||||
.stream_prompt(request)
|
||||
.await
|
||||
.map_err(|e| RpcError::internal_error(format!("Chat request failed: {}", e)))?;
|
||||
let mut content = String::new();
|
||||
while let Some(chunk) = stream.next().await {
|
||||
match chunk {
|
||||
Ok(resp) => {
|
||||
content.push_str(&resp.message.content);
|
||||
if resp.is_final {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(RpcError::internal_error(format!("Stream error: {}", e)));
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
async fn handle_request(req: &RpcRequest) -> Result<Value, RpcError> {
|
||||
match req.method.as_str() {
|
||||
methods::INITIALIZE => {
|
||||
let params = req
|
||||
.params
|
||||
.as_ref()
|
||||
.ok_or_else(|| RpcError::invalid_params("Missing params for initialize"))?;
|
||||
let init: InitializeParams = serde_json::from_value(params.clone())
|
||||
.map_err(|e| RpcError::invalid_params(format!("Invalid init params: {}", e)))?;
|
||||
if !init.protocol_version.eq(PROTOCOL_VERSION) {
|
||||
return Err(RpcError::new(
|
||||
ErrorCode::INVALID_REQUEST,
|
||||
format!(
|
||||
"Incompatible protocol version. Client: {}, Server: {}",
|
||||
init.protocol_version, PROTOCOL_VERSION
|
||||
),
|
||||
));
|
||||
}
|
||||
let result = InitializeResult {
|
||||
protocol_version: PROTOCOL_VERSION.to_string(),
|
||||
server_info: ServerInfo {
|
||||
name: "owlen-mcp-llm-server".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
},
|
||||
capabilities: ServerCapabilities {
|
||||
supports_tools: Some(true),
|
||||
supports_resources: Some(false),
|
||||
supports_streaming: Some(true),
|
||||
},
|
||||
};
|
||||
serde_json::to_value(result).map_err(|e| {
|
||||
RpcError::internal_error(format!("Failed to serialize init result: {}", e))
|
||||
})
|
||||
}
|
||||
methods::TOOLS_LIST => {
|
||||
let tools = vec![
|
||||
generate_text_descriptor(),
|
||||
resources_get_descriptor(),
|
||||
resources_list_descriptor(),
|
||||
];
|
||||
Ok(json!(tools))
|
||||
}
|
||||
// New method to list available Ollama models via the provider.
|
||||
methods::MODELS_LIST => {
|
||||
let provider = create_provider()?;
|
||||
let models = provider
|
||||
.list_models()
|
||||
.await
|
||||
.map_err(|e| RpcError::internal_error(format!("Failed to list models: {}", e)))?;
|
||||
serde_json::to_value(models).map_err(|e| {
|
||||
RpcError::internal_error(format!("Failed to serialize model list: {}", e))
|
||||
})
|
||||
}
|
||||
methods::TOOLS_CALL => {
|
||||
// For streaming we will send incremental notifications directly from here.
|
||||
// The caller (main loop) will handle writing the final response.
|
||||
Err(RpcError::internal_error(
|
||||
"TOOLS_CALL should be handled in main loop for streaming",
|
||||
))
|
||||
}
|
||||
_ => Err(RpcError::method_not_found(&req.method)),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let root = env::current_dir()?; // not used but kept for parity
|
||||
let mut stdin = io::BufReader::new(io::stdin());
|
||||
let mut stdout = io::stdout();
|
||||
loop {
|
||||
let mut line = String::new();
|
||||
match stdin.read_line(&mut line).await {
|
||||
Ok(0) => break,
|
||||
Ok(_) => {
|
||||
let req: RpcRequest = match serde_json::from_str(&line) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
let err = RpcErrorResponse::new(
|
||||
RequestId::Number(0),
|
||||
RpcError::parse_error(format!("Parse error: {}", e)),
|
||||
);
|
||||
let s = serde_json::to_string(&err)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let id = req.id.clone();
|
||||
// Streaming tool calls (generate_text) are handled specially to emit incremental notifications.
|
||||
if req.method == methods::TOOLS_CALL {
|
||||
// Parse the tool call
|
||||
let params = match &req.params {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
let err_resp = RpcErrorResponse::new(
|
||||
id.clone(),
|
||||
RpcError::invalid_params("Missing params for tool call"),
|
||||
);
|
||||
let s = serde_json::to_string(&err_resp)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let call: McpToolCall = match serde_json::from_value(params.clone()) {
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
let err_resp = RpcErrorResponse::new(
|
||||
id.clone(),
|
||||
RpcError::invalid_params(format!("Invalid tool call: {}", e)),
|
||||
);
|
||||
let s = serde_json::to_string(&err_resp)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
// Dispatch based on the requested tool name.
|
||||
// Handle resources tools manually.
|
||||
if call.name.starts_with("resources/get") {
|
||||
let path = call
|
||||
.arguments
|
||||
.get("path")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("");
|
||||
match std::fs::read_to_string(path) {
|
||||
Ok(content) => {
|
||||
let response = McpToolResponse {
|
||||
name: call.name,
|
||||
success: true,
|
||||
output: json!(content),
|
||||
metadata: HashMap::new(),
|
||||
duration_ms: 0,
|
||||
};
|
||||
let payload = match serde_json::to_value(&response) {
|
||||
Ok(value) => value,
|
||||
Err(e) => {
|
||||
let err_resp = RpcErrorResponse::new(
|
||||
id.clone(),
|
||||
RpcError::internal_error(format!(
|
||||
"Failed to serialize resource response: {}",
|
||||
e
|
||||
)),
|
||||
);
|
||||
let s = serde_json::to_string(&err_resp)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let final_resp = RpcResponse::new(id.clone(), payload);
|
||||
let s = serde_json::to_string(&final_resp)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
let err_resp = RpcErrorResponse::new(
|
||||
id.clone(),
|
||||
RpcError::internal_error(format!("Failed to read file: {}", e)),
|
||||
);
|
||||
let s = serde_json::to_string(&err_resp)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
if call.name.starts_with("resources/list") {
|
||||
let path = call
|
||||
.arguments
|
||||
.get("path")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or(".");
|
||||
match std::fs::read_dir(path) {
|
||||
Ok(entries) => {
|
||||
let mut names = Vec::new();
|
||||
for entry in entries.flatten() {
|
||||
if let Some(name) = entry.file_name().to_str() {
|
||||
names.push(name.to_string());
|
||||
}
|
||||
}
|
||||
let response = McpToolResponse {
|
||||
name: call.name,
|
||||
success: true,
|
||||
output: json!(names),
|
||||
metadata: HashMap::new(),
|
||||
duration_ms: 0,
|
||||
};
|
||||
let payload = match serde_json::to_value(&response) {
|
||||
Ok(value) => value,
|
||||
Err(e) => {
|
||||
let err_resp = RpcErrorResponse::new(
|
||||
id.clone(),
|
||||
RpcError::internal_error(format!(
|
||||
"Failed to serialize directory listing: {}",
|
||||
e
|
||||
)),
|
||||
);
|
||||
let s = serde_json::to_string(&err_resp)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let final_resp = RpcResponse::new(id.clone(), payload);
|
||||
let s = serde_json::to_string(&final_resp)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
continue;
|
||||
}
|
||||
Err(e) => {
|
||||
let err_resp = RpcErrorResponse::new(
|
||||
id.clone(),
|
||||
RpcError::internal_error(format!("Failed to list dir: {}", e)),
|
||||
);
|
||||
let s = serde_json::to_string(&err_resp)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Expect generate_text tool for the remaining path.
|
||||
if call.name != "generate_text" {
|
||||
let err_resp =
|
||||
RpcErrorResponse::new(id.clone(), RpcError::tool_not_found(&call.name));
|
||||
let s = serde_json::to_string(&err_resp)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
continue;
|
||||
}
|
||||
let args: GenerateTextArgs =
|
||||
match serde_json::from_value(call.arguments.clone()) {
|
||||
Ok(a) => a,
|
||||
Err(e) => {
|
||||
let err_resp = RpcErrorResponse::new(
|
||||
id.clone(),
|
||||
RpcError::invalid_params(format!("Invalid arguments: {}", e)),
|
||||
);
|
||||
let s = serde_json::to_string(&err_resp)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// Initialize provider and start streaming
|
||||
let provider = match create_provider() {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
let err_resp = RpcErrorResponse::new(
|
||||
id.clone(),
|
||||
RpcError::internal_error(format!(
|
||||
"Failed to initialize provider: {:?}",
|
||||
e
|
||||
)),
|
||||
);
|
||||
let s = serde_json::to_string(&err_resp)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let parameters = ChatParameters {
|
||||
temperature: args.temperature,
|
||||
max_tokens: args.max_tokens.map(|v| v as u32),
|
||||
stream: true,
|
||||
extra: HashMap::new(),
|
||||
};
|
||||
let request = ChatRequest {
|
||||
model: args.model,
|
||||
messages: args.messages,
|
||||
parameters,
|
||||
tools: None,
|
||||
};
|
||||
let mut stream = match provider.stream_prompt(request).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
let err_resp = RpcErrorResponse::new(
|
||||
id.clone(),
|
||||
RpcError::internal_error(format!("Chat request failed: {}", e)),
|
||||
);
|
||||
let s = serde_json::to_string(&err_resp)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
// Accumulate full content while sending incremental progress notifications
|
||||
let mut final_content = String::new();
|
||||
while let Some(chunk) = stream.next().await {
|
||||
match chunk {
|
||||
Ok(resp) => {
|
||||
// Append chunk to the final content buffer
|
||||
final_content.push_str(&resp.message.content);
|
||||
// Emit a progress notification for the UI
|
||||
let notif = RpcNotification::new(
|
||||
"tools/call/progress",
|
||||
Some(json!({ "content": resp.message.content })),
|
||||
);
|
||||
let s = serde_json::to_string(¬if)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
if resp.is_final {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let err_resp = RpcErrorResponse::new(
|
||||
id.clone(),
|
||||
RpcError::internal_error(format!("Stream error: {}", e)),
|
||||
);
|
||||
let s = serde_json::to_string(&err_resp)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// After streaming, send the final tool response containing the full content
|
||||
let final_output = final_content.clone();
|
||||
let response = McpToolResponse {
|
||||
name: call.name,
|
||||
success: true,
|
||||
output: json!(final_output),
|
||||
metadata: HashMap::new(),
|
||||
duration_ms: 0,
|
||||
};
|
||||
let payload = match serde_json::to_value(&response) {
|
||||
Ok(value) => value,
|
||||
Err(e) => {
|
||||
let err_resp = RpcErrorResponse::new(
|
||||
id.clone(),
|
||||
RpcError::internal_error(format!(
|
||||
"Failed to serialize final streaming response: {}",
|
||||
e
|
||||
)),
|
||||
);
|
||||
let s = serde_json::to_string(&err_resp)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let final_resp = RpcResponse::new(id.clone(), payload);
|
||||
let s = serde_json::to_string(&final_resp)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
continue;
|
||||
}
|
||||
// Non‑streaming requests are handled by the generic handler
|
||||
match handle_request(&req).await {
|
||||
Ok(res) => {
|
||||
let resp = RpcResponse::new(id, res);
|
||||
let s = serde_json::to_string(&resp)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
}
|
||||
Err(err) => {
|
||||
let err_resp = RpcErrorResponse::new(id, err);
|
||||
let s = serde_json::to_string(&err_resp)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Read error: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
[package]
|
||||
name = "owlen-mcp-prompt-server"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
description = "MCP server that renders prompt templates (YAML) for Owlen"
|
||||
license = "AGPL-3.0"
|
||||
|
||||
[dependencies]
|
||||
owlen-core = { path = "../../owlen-core" }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
serde_yaml = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
handlebars = { workspace = true }
|
||||
dirs = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
|
||||
[lib]
|
||||
name = "owlen_mcp_prompt_server"
|
||||
path = "src/lib.rs"
|
||||
@@ -1,415 +0,0 @@
|
||||
//! MCP server for rendering prompt templates with YAML storage and Handlebars rendering.
|
||||
//!
|
||||
//! Templates are stored in `~/.config/owlen/prompts/` as YAML files.
|
||||
//! Provides full Handlebars templating support for dynamic prompt generation.
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use handlebars::Handlebars;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Value, json};
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
|
||||
use owlen_core::mcp::protocol::{
|
||||
ErrorCode, InitializeParams, InitializeResult, PROTOCOL_VERSION, RequestId, RpcError,
|
||||
RpcErrorResponse, RpcRequest, RpcResponse, ServerCapabilities, ServerInfo, methods,
|
||||
};
|
||||
use owlen_core::mcp::{McpToolCall, McpToolDescriptor, McpToolResponse};
|
||||
use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt};
|
||||
|
||||
/// Prompt template definition
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PromptTemplate {
|
||||
/// Template name
|
||||
pub name: String,
|
||||
/// Template version
|
||||
pub version: String,
|
||||
/// Optional mode restriction
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub mode: Option<String>,
|
||||
/// Handlebars template content
|
||||
pub template: String,
|
||||
/// Template description
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub description: Option<String>,
|
||||
}
|
||||
|
||||
/// Prompt server managing templates
|
||||
pub struct PromptServer {
|
||||
templates: Arc<RwLock<HashMap<String, PromptTemplate>>>,
|
||||
handlebars: Handlebars<'static>,
|
||||
templates_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl PromptServer {
|
||||
/// Create a new prompt server
|
||||
pub fn new() -> Result<Self> {
|
||||
let templates_dir = Self::get_templates_dir()?;
|
||||
|
||||
// Create templates directory if it doesn't exist
|
||||
if !templates_dir.exists() {
|
||||
fs::create_dir_all(&templates_dir)?;
|
||||
Self::create_default_templates(&templates_dir)?;
|
||||
}
|
||||
|
||||
let mut server = Self {
|
||||
templates: Arc::new(RwLock::new(HashMap::new())),
|
||||
handlebars: Handlebars::new(),
|
||||
templates_dir,
|
||||
};
|
||||
|
||||
// Load all templates
|
||||
server.load_templates()?;
|
||||
|
||||
Ok(server)
|
||||
}
|
||||
|
||||
/// Get the templates directory path
|
||||
fn get_templates_dir() -> Result<PathBuf> {
|
||||
let config_dir = dirs::config_dir().context("Could not determine config directory")?;
|
||||
Ok(config_dir.join("owlen").join("prompts"))
|
||||
}
|
||||
|
||||
/// Create default template examples
|
||||
fn create_default_templates(dir: &Path) -> Result<()> {
|
||||
let chat_mode_system = PromptTemplate {
|
||||
name: "chat_mode_system".to_string(),
|
||||
version: "1.0".to_string(),
|
||||
mode: Some("chat".to_string()),
|
||||
description: Some("System prompt for chat mode".to_string()),
|
||||
template: r#"You are Owlen, a helpful AI assistant. You have access to these tools:
|
||||
{{#each tools}}
|
||||
- {{name}}: {{description}}
|
||||
{{/each}}
|
||||
|
||||
Use the ReAct pattern:
|
||||
THOUGHT: Your reasoning
|
||||
ACTION: tool_name
|
||||
ACTION_INPUT: {"param": "value"}
|
||||
|
||||
When you have enough information:
|
||||
FINAL_ANSWER: Your response"#
|
||||
.to_string(),
|
||||
};
|
||||
|
||||
let code_mode_system = PromptTemplate {
|
||||
name: "code_mode_system".to_string(),
|
||||
version: "1.0".to_string(),
|
||||
mode: Some("code".to_string()),
|
||||
description: Some("System prompt for code mode".to_string()),
|
||||
template: r#"You are Owlen in code mode, with full development capabilities. You have access to:
|
||||
{{#each tools}}
|
||||
- {{name}}: {{description}}
|
||||
{{/each}}
|
||||
|
||||
Use the ReAct pattern to solve coding tasks:
|
||||
THOUGHT: Analyze what needs to be done
|
||||
ACTION: tool_name (compile_project, run_tests, format_code, lint_code, etc.)
|
||||
ACTION_INPUT: {"param": "value"}
|
||||
|
||||
Continue iterating until the task is complete, then provide:
|
||||
FINAL_ANSWER: Summary of what was done"#
|
||||
.to_string(),
|
||||
};
|
||||
|
||||
// Save templates
|
||||
let chat_path = dir.join("chat_mode_system.yaml");
|
||||
let code_path = dir.join("code_mode_system.yaml");
|
||||
|
||||
fs::write(chat_path, serde_yaml::to_string(&chat_mode_system)?)?;
|
||||
fs::write(code_path, serde_yaml::to_string(&code_mode_system)?)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load all templates from the templates directory
|
||||
fn load_templates(&mut self) -> Result<()> {
|
||||
let entries = fs::read_dir(&self.templates_dir)?;
|
||||
|
||||
for entry in entries {
|
||||
let entry = entry?;
|
||||
let path = entry.path();
|
||||
|
||||
if path.extension().and_then(|s| s.to_str()) == Some("yaml")
|
||||
|| path.extension().and_then(|s| s.to_str()) == Some("yml")
|
||||
{
|
||||
match self.load_template(&path) {
|
||||
Ok(template) => {
|
||||
// Register with Handlebars
|
||||
if let Err(e) = self
|
||||
.handlebars
|
||||
.register_template_string(&template.name, &template.template)
|
||||
{
|
||||
eprintln!(
|
||||
"Warning: Failed to register template {}: {}",
|
||||
template.name, e
|
||||
);
|
||||
} else {
|
||||
let mut templates = self.templates.blocking_write();
|
||||
templates.insert(template.name.clone(), template);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Warning: Failed to load template {:?}: {}", path, e);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load a single template from file
|
||||
fn load_template(&self, path: &Path) -> Result<PromptTemplate> {
|
||||
let content = fs::read_to_string(path)?;
|
||||
let template: PromptTemplate = serde_yaml::from_str(&content)?;
|
||||
Ok(template)
|
||||
}
|
||||
|
||||
/// Get a template by name
|
||||
pub async fn get_template(&self, name: &str) -> Option<PromptTemplate> {
|
||||
let templates = self.templates.read().await;
|
||||
templates.get(name).cloned()
|
||||
}
|
||||
|
||||
/// List all available templates
|
||||
pub async fn list_templates(&self) -> Vec<String> {
|
||||
let templates = self.templates.read().await;
|
||||
templates.keys().cloned().collect()
|
||||
}
|
||||
|
||||
/// Render a template with given variables
|
||||
pub fn render_template(&self, name: &str, vars: &Value) -> Result<String> {
|
||||
self.handlebars
|
||||
.render(name, vars)
|
||||
.context("Failed to render template")
|
||||
}
|
||||
|
||||
/// Reload all templates from disk
|
||||
pub async fn reload_templates(&mut self) -> Result<()> {
|
||||
{
|
||||
let mut templates = self.templates.write().await;
|
||||
templates.clear();
|
||||
}
|
||||
self.handlebars = Handlebars::new();
|
||||
self.load_templates()
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let mut stdin = io::BufReader::new(io::stdin());
|
||||
let mut stdout = io::stdout();
|
||||
|
||||
let server = Arc::new(tokio::sync::Mutex::new(PromptServer::new()?));
|
||||
|
||||
loop {
|
||||
let mut line = String::new();
|
||||
match stdin.read_line(&mut line).await {
|
||||
Ok(0) => break, // EOF
|
||||
Ok(_) => {
|
||||
let req: RpcRequest = match serde_json::from_str(&line) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
let err = RpcErrorResponse::new(
|
||||
RequestId::Number(0),
|
||||
RpcError::parse_error(format!("Parse error: {}", e)),
|
||||
);
|
||||
let s = serde_json::to_string(&err)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let resp = handle_request(req.clone(), server.clone()).await;
|
||||
match resp {
|
||||
Ok(r) => {
|
||||
let s = serde_json::to_string(&r)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
}
|
||||
Err(e) => {
|
||||
let err = RpcErrorResponse::new(req.id.clone(), e);
|
||||
let s = serde_json::to_string(&err)?;
|
||||
stdout.write_all(s.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("Error reading stdin: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
async fn handle_request(
|
||||
req: RpcRequest,
|
||||
server: Arc<tokio::sync::Mutex<PromptServer>>,
|
||||
) -> Result<RpcResponse, RpcError> {
|
||||
match req.method.as_str() {
|
||||
methods::INITIALIZE => {
|
||||
let params: InitializeParams =
|
||||
serde_json::from_value(req.params.unwrap_or_else(|| json!({})))
|
||||
.map_err(|e| RpcError::invalid_params(format!("Invalid init params: {}", e)))?;
|
||||
if !params.protocol_version.eq(PROTOCOL_VERSION) {
|
||||
return Err(RpcError::new(
|
||||
ErrorCode::INVALID_REQUEST,
|
||||
format!(
|
||||
"Incompatible protocol version. Client: {}, Server: {}",
|
||||
params.protocol_version, PROTOCOL_VERSION
|
||||
),
|
||||
));
|
||||
}
|
||||
let result = InitializeResult {
|
||||
protocol_version: PROTOCOL_VERSION.to_string(),
|
||||
server_info: ServerInfo {
|
||||
name: "owlen-mcp-prompt-server".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
},
|
||||
capabilities: ServerCapabilities {
|
||||
supports_tools: Some(true),
|
||||
supports_resources: Some(false),
|
||||
supports_streaming: Some(false),
|
||||
},
|
||||
};
|
||||
let payload = serde_json::to_value(result).map_err(|e| {
|
||||
RpcError::internal_error(format!("Failed to serialize initialize result: {}", e))
|
||||
})?;
|
||||
Ok(RpcResponse::new(req.id, payload))
|
||||
}
|
||||
methods::TOOLS_LIST => {
|
||||
let tools = vec![
|
||||
McpToolDescriptor {
|
||||
name: "get_prompt".to_string(),
|
||||
description: "Retrieve a prompt template by name".to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "description": "Template name"}
|
||||
},
|
||||
"required": ["name"]
|
||||
}),
|
||||
requires_network: false,
|
||||
requires_filesystem: vec![],
|
||||
},
|
||||
McpToolDescriptor {
|
||||
name: "render_prompt".to_string(),
|
||||
description: "Render a prompt template with Handlebars variables".to_string(),
|
||||
input_schema: json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string", "description": "Template name"},
|
||||
"vars": {"type": "object", "description": "Variables for Handlebars rendering"}
|
||||
},
|
||||
"required": ["name"]
|
||||
}),
|
||||
requires_network: false,
|
||||
requires_filesystem: vec![],
|
||||
},
|
||||
McpToolDescriptor {
|
||||
name: "list_prompts".to_string(),
|
||||
description: "List all available prompt templates".to_string(),
|
||||
input_schema: json!({"type": "object", "properties": {}}),
|
||||
requires_network: false,
|
||||
requires_filesystem: vec![],
|
||||
},
|
||||
McpToolDescriptor {
|
||||
name: "reload_prompts".to_string(),
|
||||
description: "Reload all prompts from disk".to_string(),
|
||||
input_schema: json!({"type": "object", "properties": {}}),
|
||||
requires_network: false,
|
||||
requires_filesystem: vec![],
|
||||
},
|
||||
];
|
||||
Ok(RpcResponse::new(req.id, json!(tools)))
|
||||
}
|
||||
methods::TOOLS_CALL => {
|
||||
let call: McpToolCall = serde_json::from_value(req.params.unwrap_or_else(|| json!({})))
|
||||
.map_err(|e| RpcError::invalid_params(format!("Invalid tool call: {}", e)))?;
|
||||
|
||||
let result = match call.name.as_str() {
|
||||
"get_prompt" => {
|
||||
let name = call
|
||||
.arguments
|
||||
.get("name")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| RpcError::invalid_params("Missing 'name' parameter"))?;
|
||||
|
||||
let srv = server.lock().await;
|
||||
match srv.get_template(name).await {
|
||||
Some(template) => match serde_json::to_value(template) {
|
||||
Ok(serialized) => {
|
||||
json!({"success": true, "template": serialized})
|
||||
}
|
||||
Err(e) => {
|
||||
return Err(RpcError::internal_error(format!(
|
||||
"Failed to serialize template '{}': {}",
|
||||
name, e
|
||||
)));
|
||||
}
|
||||
},
|
||||
None => json!({"success": false, "error": "Template not found"}),
|
||||
}
|
||||
}
|
||||
"render_prompt" => {
|
||||
let name = call
|
||||
.arguments
|
||||
.get("name")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| RpcError::invalid_params("Missing 'name' parameter"))?;
|
||||
|
||||
let default_vars = json!({});
|
||||
let vars = call.arguments.get("vars").unwrap_or(&default_vars);
|
||||
|
||||
let srv = server.lock().await;
|
||||
match srv.render_template(name, vars) {
|
||||
Ok(rendered) => json!({"success": true, "rendered": rendered}),
|
||||
Err(e) => json!({"success": false, "error": e.to_string()}),
|
||||
}
|
||||
}
|
||||
"list_prompts" => {
|
||||
let srv = server.lock().await;
|
||||
let templates = srv.list_templates().await;
|
||||
json!({"success": true, "templates": templates})
|
||||
}
|
||||
"reload_prompts" => {
|
||||
let mut srv = server.lock().await;
|
||||
match srv.reload_templates().await {
|
||||
Ok(_) => json!({"success": true, "message": "Prompts reloaded"}),
|
||||
Err(e) => json!({"success": false, "error": e.to_string()}),
|
||||
}
|
||||
}
|
||||
_ => return Err(RpcError::method_not_found(&call.name)),
|
||||
};
|
||||
|
||||
let resp = McpToolResponse {
|
||||
name: call.name,
|
||||
success: result
|
||||
.get("success")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false),
|
||||
output: result,
|
||||
metadata: HashMap::new(),
|
||||
duration_ms: 0,
|
||||
};
|
||||
|
||||
let payload = serde_json::to_value(resp).map_err(|e| {
|
||||
RpcError::internal_error(format!("Failed to serialize tool response: {}", e))
|
||||
})?;
|
||||
Ok(RpcResponse::new(req.id, payload))
|
||||
}
|
||||
_ => Err(RpcError::method_not_found(&req.method)),
|
||||
}
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
prompt: |
|
||||
Hello {{name}}!
|
||||
Your role is: {{role}}.
|
||||
@@ -1,12 +0,0 @@
|
||||
[package]
|
||||
name = "owlen-mcp-server"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
|
||||
[dependencies]
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
anyhow = { workspace = true }
|
||||
path-clean = "1.0"
|
||||
owlen-core = { path = "../../owlen-core" }
|
||||
@@ -1,246 +0,0 @@
|
||||
use owlen_core::mcp::protocol::{
|
||||
ErrorCode, InitializeParams, InitializeResult, PROTOCOL_VERSION, RequestId, RpcError,
|
||||
RpcErrorResponse, RpcRequest, RpcResponse, ServerCapabilities, ServerInfo, is_compatible,
|
||||
};
|
||||
use path_clean::PathClean;
|
||||
use serde::Deserialize;
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct FileArgs {
|
||||
path: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct WriteArgs {
|
||||
path: String,
|
||||
content: String,
|
||||
}
|
||||
|
||||
async fn handle_request(req: &RpcRequest, root: &Path) -> Result<serde_json::Value, RpcError> {
|
||||
match req.method.as_str() {
|
||||
"initialize" => {
|
||||
let params = req
|
||||
.params
|
||||
.as_ref()
|
||||
.ok_or_else(|| RpcError::invalid_params("Missing params for initialize"))?;
|
||||
|
||||
let init_params: InitializeParams =
|
||||
serde_json::from_value(params.clone()).map_err(|e| {
|
||||
RpcError::invalid_params(format!("Invalid initialize params: {}", e))
|
||||
})?;
|
||||
|
||||
// Check protocol version compatibility
|
||||
if !is_compatible(&init_params.protocol_version, PROTOCOL_VERSION) {
|
||||
return Err(RpcError::new(
|
||||
ErrorCode::INVALID_REQUEST,
|
||||
format!(
|
||||
"Incompatible protocol version. Client: {}, Server: {}",
|
||||
init_params.protocol_version, PROTOCOL_VERSION
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
// Build initialization result
|
||||
let result = InitializeResult {
|
||||
protocol_version: PROTOCOL_VERSION.to_string(),
|
||||
server_info: ServerInfo {
|
||||
name: "owlen-mcp-server".to_string(),
|
||||
version: env!("CARGO_PKG_VERSION").to_string(),
|
||||
},
|
||||
capabilities: ServerCapabilities {
|
||||
supports_tools: Some(false),
|
||||
supports_resources: Some(true), // Supports read, write, delete
|
||||
supports_streaming: Some(false),
|
||||
},
|
||||
};
|
||||
|
||||
Ok(serde_json::to_value(result).map_err(|e| {
|
||||
RpcError::internal_error(format!("Failed to serialize result: {}", e))
|
||||
})?)
|
||||
}
|
||||
"resources/list" => {
|
||||
let params = req
|
||||
.params
|
||||
.as_ref()
|
||||
.ok_or_else(|| RpcError::invalid_params("Missing params"))?;
|
||||
let args: FileArgs = serde_json::from_value(params.clone())
|
||||
.map_err(|e| RpcError::invalid_params(format!("Invalid params: {}", e)))?;
|
||||
resources_list(&args.path, root).await
|
||||
}
|
||||
"resources/get" => {
|
||||
let params = req
|
||||
.params
|
||||
.as_ref()
|
||||
.ok_or_else(|| RpcError::invalid_params("Missing params"))?;
|
||||
let args: FileArgs = serde_json::from_value(params.clone())
|
||||
.map_err(|e| RpcError::invalid_params(format!("Invalid params: {}", e)))?;
|
||||
resources_get(&args.path, root).await
|
||||
}
|
||||
"resources/write" => {
|
||||
let params = req
|
||||
.params
|
||||
.as_ref()
|
||||
.ok_or_else(|| RpcError::invalid_params("Missing params"))?;
|
||||
let args: WriteArgs = serde_json::from_value(params.clone())
|
||||
.map_err(|e| RpcError::invalid_params(format!("Invalid params: {}", e)))?;
|
||||
resources_write(&args.path, &args.content, root).await
|
||||
}
|
||||
"resources/delete" => {
|
||||
let params = req
|
||||
.params
|
||||
.as_ref()
|
||||
.ok_or_else(|| RpcError::invalid_params("Missing params"))?;
|
||||
let args: FileArgs = serde_json::from_value(params.clone())
|
||||
.map_err(|e| RpcError::invalid_params(format!("Invalid params: {}", e)))?;
|
||||
resources_delete(&args.path, root).await
|
||||
}
|
||||
_ => Err(RpcError::method_not_found(&req.method)),
|
||||
}
|
||||
}
|
||||
|
||||
fn sanitize_path(path: &str, root: &Path) -> Result<PathBuf, RpcError> {
|
||||
let path = Path::new(path);
|
||||
let path = if path.is_absolute() {
|
||||
path.strip_prefix("/")
|
||||
.map_err(|_| RpcError::invalid_params("Invalid path"))?
|
||||
.to_path_buf()
|
||||
} else {
|
||||
path.to_path_buf()
|
||||
};
|
||||
|
||||
let full_path = root.join(path).clean();
|
||||
|
||||
if !full_path.starts_with(root) {
|
||||
return Err(RpcError::path_traversal());
|
||||
}
|
||||
|
||||
Ok(full_path)
|
||||
}
|
||||
|
||||
async fn resources_list(path: &str, root: &Path) -> Result<serde_json::Value, RpcError> {
|
||||
let full_path = sanitize_path(path, root)?;
|
||||
|
||||
let entries = fs::read_dir(full_path).map_err(|e| {
|
||||
RpcError::new(
|
||||
ErrorCode::RESOURCE_NOT_FOUND,
|
||||
format!("Failed to read directory: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
let mut result = Vec::new();
|
||||
for entry in entries {
|
||||
let entry = entry.map_err(|e| {
|
||||
RpcError::internal_error(format!("Failed to read directory entry: {}", e))
|
||||
})?;
|
||||
result.push(entry.file_name().to_string_lossy().to_string());
|
||||
}
|
||||
|
||||
Ok(serde_json::json!(result))
|
||||
}
|
||||
|
||||
async fn resources_get(path: &str, root: &Path) -> Result<serde_json::Value, RpcError> {
|
||||
let full_path = sanitize_path(path, root)?;
|
||||
|
||||
let content = fs::read_to_string(full_path).map_err(|e| {
|
||||
RpcError::new(
|
||||
ErrorCode::RESOURCE_NOT_FOUND,
|
||||
format!("Failed to read file: {}", e),
|
||||
)
|
||||
})?;
|
||||
|
||||
Ok(serde_json::json!(content))
|
||||
}
|
||||
|
||||
async fn resources_write(
|
||||
path: &str,
|
||||
content: &str,
|
||||
root: &Path,
|
||||
) -> Result<serde_json::Value, RpcError> {
|
||||
let full_path = sanitize_path(path, root)?;
|
||||
// Ensure parent directory exists
|
||||
if let Some(parent) = full_path.parent() {
|
||||
std::fs::create_dir_all(parent).map_err(|e| {
|
||||
RpcError::internal_error(format!("Failed to create parent directories: {}", e))
|
||||
})?;
|
||||
}
|
||||
std::fs::write(full_path, content)
|
||||
.map_err(|e| RpcError::internal_error(format!("Failed to write file: {}", e)))?;
|
||||
Ok(serde_json::json!(null))
|
||||
}
|
||||
|
||||
async fn resources_delete(path: &str, root: &Path) -> Result<serde_json::Value, RpcError> {
|
||||
let full_path = sanitize_path(path, root)?;
|
||||
if full_path.is_file() {
|
||||
std::fs::remove_file(full_path)
|
||||
.map_err(|e| RpcError::internal_error(format!("Failed to delete file: {}", e)))?;
|
||||
Ok(serde_json::json!(null))
|
||||
} else {
|
||||
Err(RpcError::new(
|
||||
ErrorCode::RESOURCE_NOT_FOUND,
|
||||
"Path does not refer to a file",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let root = env::current_dir()?;
|
||||
let mut stdin = io::BufReader::new(io::stdin());
|
||||
let mut stdout = io::stdout();
|
||||
|
||||
loop {
|
||||
let mut line = String::new();
|
||||
match stdin.read_line(&mut line).await {
|
||||
Ok(0) => {
|
||||
// EOF
|
||||
break;
|
||||
}
|
||||
Ok(_) => {
|
||||
let req: RpcRequest = match serde_json::from_str(&line) {
|
||||
Ok(req) => req,
|
||||
Err(e) => {
|
||||
let err_resp = RpcErrorResponse::new(
|
||||
RequestId::Number(0),
|
||||
RpcError::parse_error(format!("Parse error: {}", e)),
|
||||
);
|
||||
let resp_str = serde_json::to_string(&err_resp)?;
|
||||
stdout.write_all(resp_str.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let request_id = req.id.clone();
|
||||
|
||||
match handle_request(&req, &root).await {
|
||||
Ok(result) => {
|
||||
let resp = RpcResponse::new(request_id, result);
|
||||
let resp_str = serde_json::to_string(&resp)?;
|
||||
stdout.write_all(resp_str.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
}
|
||||
Err(error) => {
|
||||
let err_resp = RpcErrorResponse::new(request_id, error);
|
||||
let resp_str = serde_json::to_string(&err_resp)?;
|
||||
stdout.write_all(resp_str.as_bytes()).await?;
|
||||
stdout.write_all(b"\n").await?;
|
||||
stdout.flush().await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
// Handle read error
|
||||
eprintln!("Error reading from stdin: {}", e);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,60 +0,0 @@
|
||||
[package]
|
||||
name = "owlen-cli"
|
||||
version.workspace = true
|
||||
edition.workspace = true
|
||||
authors.workspace = true
|
||||
license.workspace = true
|
||||
repository.workspace = true
|
||||
homepage.workspace = true
|
||||
description = "Command-line interface for OWLEN LLM client"
|
||||
|
||||
[features]
|
||||
default = ["chat-client"]
|
||||
chat-client = ["owlen-tui"]
|
||||
|
||||
[[bin]]
|
||||
name = "owlen"
|
||||
path = "src/main.rs"
|
||||
required-features = ["chat-client"]
|
||||
|
||||
[[bin]]
|
||||
name = "owlen-code"
|
||||
path = "src/code_main.rs"
|
||||
required-features = ["chat-client"]
|
||||
|
||||
[[bin]]
|
||||
name = "owlen-agent"
|
||||
path = "src/agent_main.rs"
|
||||
required-features = ["chat-client"]
|
||||
|
||||
[dependencies]
|
||||
owlen-core = { path = "../owlen-core" }
|
||||
owlen-providers = { path = "../owlen-providers" }
|
||||
# Optional TUI dependency, enabled by the "chat-client" feature.
|
||||
owlen-tui = { path = "../owlen-tui", optional = true }
|
||||
log = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
|
||||
# CLI framework
|
||||
clap = { workspace = true, features = ["derive"] }
|
||||
|
||||
# Async runtime
|
||||
tokio = { workspace = true }
|
||||
tokio-util = { workspace = true }
|
||||
|
||||
# TUI framework
|
||||
ratatui = { workspace = true }
|
||||
crossterm = { workspace = true }
|
||||
|
||||
# Utilities
|
||||
anyhow = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
dirs = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true }
|
||||
tokio-test = { workspace = true }
|
||||
@@ -1,15 +0,0 @@
|
||||
# Owlen CLI
|
||||
|
||||
This crate is the command-line entry point for the Owlen application.
|
||||
|
||||
It is responsible for:
|
||||
|
||||
- Parsing command-line arguments.
|
||||
- Loading the configuration.
|
||||
- Initializing the providers.
|
||||
- Starting the `owlen-tui` application.
|
||||
|
||||
There are two binaries:
|
||||
|
||||
- `owlen`: The main chat application.
|
||||
- `owlen-code`: A specialized version for code-related tasks.
|
||||
@@ -1,31 +0,0 @@
|
||||
use std::process::Command;
|
||||
|
||||
fn main() {
|
||||
const MIN_VERSION: (u32, u32, u32) = (1, 75, 0);
|
||||
|
||||
let rustc = std::env::var("RUSTC").unwrap_or_else(|_| "rustc".into());
|
||||
let output = Command::new(&rustc)
|
||||
.arg("--version")
|
||||
.output()
|
||||
.expect("failed to invoke rustc");
|
||||
|
||||
let version_line = String::from_utf8_lossy(&output.stdout);
|
||||
let version_str = version_line.split_whitespace().nth(1).unwrap_or("0.0.0");
|
||||
let sanitized = version_str.split('-').next().unwrap_or(version_str);
|
||||
|
||||
let mut parts = sanitized
|
||||
.split('.')
|
||||
.map(|part| part.parse::<u32>().unwrap_or(0));
|
||||
let current = (
|
||||
parts.next().unwrap_or(0),
|
||||
parts.next().unwrap_or(0),
|
||||
parts.next().unwrap_or(0),
|
||||
);
|
||||
|
||||
if current < MIN_VERSION {
|
||||
panic!(
|
||||
"owlen requires rustc {}.{}.{} or newer (found {version_line})",
|
||||
MIN_VERSION.0, MIN_VERSION.1, MIN_VERSION.2
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
//! Simple entry point for the ReAct agentic executor.
|
||||
//!
|
||||
//! Usage: `owlen-agent "<prompt>" [--model <model>] [--max-iter <n>]`
|
||||
//!
|
||||
//! This binary demonstrates Phase 4 without the full TUI. It creates an
|
||||
//! OllamaProvider, a RemoteMcpClient, runs the AgentExecutor and prints the
|
||||
//! final answer.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use clap::Parser;
|
||||
use owlen_cli::agent::{AgentConfig, AgentExecutor};
|
||||
use owlen_core::mcp::remote_client::RemoteMcpClient;
|
||||
|
||||
/// Command‑line arguments for the agent binary.
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(
|
||||
name = "owlen-agent",
|
||||
author,
|
||||
version,
|
||||
about = "Run the ReAct agent via MCP"
|
||||
)]
|
||||
struct Args {
|
||||
/// The initial user query.
|
||||
prompt: String,
|
||||
/// Model to use (defaults to Ollama default).
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
/// Maximum ReAct iterations.
|
||||
#[arg(long, default_value_t = 10)]
|
||||
max_iter: usize,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
// Initialise the MCP LLM client – it implements Provider and talks to the
|
||||
// MCP LLM server which wraps Ollama. This ensures all communication goes
|
||||
// through the MCP architecture (Phase 10 requirement).
|
||||
let provider = Arc::new(RemoteMcpClient::new()?);
|
||||
|
||||
// The MCP client also serves as the tool client for resource operations
|
||||
let mcp_client = Arc::clone(&provider) as Arc<RemoteMcpClient>;
|
||||
|
||||
let config = AgentConfig {
|
||||
max_iterations: args.max_iter,
|
||||
model: args.model.unwrap_or_else(|| "llama3.2:latest".to_string()),
|
||||
..AgentConfig::default()
|
||||
};
|
||||
|
||||
let executor = AgentExecutor::new(provider, mcp_client, config);
|
||||
match executor.run(args.prompt).await {
|
||||
Ok(result) => {
|
||||
println!("\n✓ Agent completed in {} iterations", result.iterations);
|
||||
println!("\nFinal answer:\n{}", result.answer);
|
||||
Ok(())
|
||||
}
|
||||
Err(e) => Err(anyhow::anyhow!(e)),
|
||||
}
|
||||
}
|
||||
@@ -1,326 +0,0 @@
|
||||
use std::borrow::Cow;
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use async_trait::async_trait;
|
||||
use crossterm::{
|
||||
event::{DisableBracketedPaste, DisableMouseCapture, EnableBracketedPaste, EnableMouseCapture},
|
||||
execute,
|
||||
terminal::{EnterAlternateScreen, LeaveAlternateScreen, disable_raw_mode, enable_raw_mode},
|
||||
};
|
||||
use futures::stream;
|
||||
use owlen_core::{
|
||||
ChatStream, Error, Provider,
|
||||
config::{Config, McpMode},
|
||||
mcp::remote_client::RemoteMcpClient,
|
||||
mode::Mode,
|
||||
provider::ProviderManager,
|
||||
providers::OllamaProvider,
|
||||
session::{ControllerEvent, SessionController},
|
||||
storage::StorageManager,
|
||||
types::{ChatRequest, ChatResponse, Message, ModelInfo},
|
||||
};
|
||||
use owlen_tui::{
|
||||
ChatApp, SessionEvent,
|
||||
app::App as RuntimeApp,
|
||||
config,
|
||||
tui_controller::{TuiController, TuiRequest},
|
||||
ui,
|
||||
};
|
||||
use ratatui::{Terminal, prelude::CrosstermBackend};
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::commands::cloud::{load_runtime_credentials, set_env_var};
|
||||
|
||||
pub async fn launch(initial_mode: Mode) -> Result<()> {
|
||||
set_env_var("OWLEN_AUTO_CONSENT", "1");
|
||||
|
||||
let color_support = detect_terminal_color_support();
|
||||
let mut cfg = config::try_load_config().unwrap_or_default();
|
||||
let _ = cfg.refresh_mcp_servers(None);
|
||||
|
||||
if let Some(previous_theme) = apply_terminal_theme(&mut cfg, &color_support) {
|
||||
let term_label = match &color_support {
|
||||
TerminalColorSupport::Limited { term } => Cow::from(term.as_str()),
|
||||
TerminalColorSupport::Full => Cow::from("current terminal"),
|
||||
};
|
||||
eprintln!(
|
||||
"Terminal '{}' lacks full 256-color support. Using '{}' theme instead of '{}'.",
|
||||
term_label, BASIC_THEME_NAME, previous_theme
|
||||
);
|
||||
} else if let TerminalColorSupport::Limited { term } = &color_support {
|
||||
eprintln!(
|
||||
"Warning: terminal '{}' may not fully support 256-color themes.",
|
||||
term
|
||||
);
|
||||
}
|
||||
|
||||
cfg.validate()?;
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
load_runtime_credentials(&mut cfg, storage.clone()).await?;
|
||||
|
||||
let (tui_tx, _tui_rx) = mpsc::unbounded_channel::<TuiRequest>();
|
||||
let tui_controller = Arc::new(TuiController::new(tui_tx));
|
||||
|
||||
let provider = build_provider(&cfg)?;
|
||||
let mut offline_notice: Option<String> = None;
|
||||
let provider = match provider.health_check().await {
|
||||
Ok(_) => provider,
|
||||
Err(err) => {
|
||||
let hint = if matches!(cfg.mcp.mode, McpMode::RemotePreferred | McpMode::RemoteOnly)
|
||||
&& !cfg.effective_mcp_servers().is_empty()
|
||||
{
|
||||
"Ensure the configured MCP server is running and reachable."
|
||||
} else {
|
||||
"Ensure Ollama is running (`ollama serve`) and reachable at the configured base_url."
|
||||
};
|
||||
let notice =
|
||||
format!("Provider health check failed: {err}. {hint} Continuing in offline mode.");
|
||||
eprintln!("{notice}");
|
||||
offline_notice = Some(notice.clone());
|
||||
let fallback_model = cfg
|
||||
.general
|
||||
.default_model
|
||||
.clone()
|
||||
.unwrap_or_else(|| "offline".to_string());
|
||||
Arc::new(OfflineProvider::new(notice, fallback_model)) as Arc<dyn Provider>
|
||||
}
|
||||
};
|
||||
|
||||
let (controller_event_tx, controller_event_rx) = mpsc::unbounded_channel::<ControllerEvent>();
|
||||
let controller = SessionController::new(
|
||||
provider,
|
||||
cfg,
|
||||
storage.clone(),
|
||||
tui_controller,
|
||||
false,
|
||||
Some(controller_event_tx),
|
||||
)
|
||||
.await?;
|
||||
let provider_manager = Arc::new(ProviderManager::default());
|
||||
let mut runtime = RuntimeApp::new(provider_manager);
|
||||
let (mut app, mut session_rx) = ChatApp::new(controller, controller_event_rx).await?;
|
||||
app.initialize_models().await?;
|
||||
if let Some(notice) = offline_notice.clone() {
|
||||
app.set_status_message(¬ice);
|
||||
app.set_system_status(notice);
|
||||
}
|
||||
|
||||
app.set_mode(initial_mode).await;
|
||||
|
||||
enable_raw_mode()?;
|
||||
let mut stdout = io::stdout();
|
||||
execute!(
|
||||
stdout,
|
||||
EnterAlternateScreen,
|
||||
EnableMouseCapture,
|
||||
EnableBracketedPaste
|
||||
)?;
|
||||
let backend = CrosstermBackend::new(stdout);
|
||||
let mut terminal = Terminal::new(backend)?;
|
||||
|
||||
let result = run_app(&mut terminal, &mut runtime, &mut app, &mut session_rx).await;
|
||||
|
||||
config::save_config(&app.config())?;
|
||||
|
||||
disable_raw_mode()?;
|
||||
execute!(
|
||||
terminal.backend_mut(),
|
||||
LeaveAlternateScreen,
|
||||
DisableMouseCapture,
|
||||
DisableBracketedPaste
|
||||
)?;
|
||||
terminal.show_cursor()?;
|
||||
|
||||
if let Err(err) = result {
|
||||
println!("{err:?}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_provider(cfg: &Config) -> Result<Arc<dyn Provider>> {
|
||||
match cfg.mcp.mode {
|
||||
McpMode::RemotePreferred => {
|
||||
let remote_result = if let Some(mcp_server) = cfg.effective_mcp_servers().first() {
|
||||
RemoteMcpClient::new_with_config(mcp_server)
|
||||
} else {
|
||||
RemoteMcpClient::new()
|
||||
};
|
||||
|
||||
match remote_result {
|
||||
Ok(client) => Ok(Arc::new(client) as Arc<dyn Provider>),
|
||||
Err(err) if cfg.mcp.allow_fallback => {
|
||||
log::warn!(
|
||||
"Remote MCP client unavailable ({}); falling back to local provider.",
|
||||
err
|
||||
);
|
||||
build_local_provider(cfg)
|
||||
}
|
||||
Err(err) => Err(anyhow!(err)),
|
||||
}
|
||||
}
|
||||
McpMode::RemoteOnly => {
|
||||
let mcp_server = cfg.effective_mcp_servers().first().ok_or_else(|| {
|
||||
anyhow!("[[mcp_servers]] must be configured when [mcp].mode = \"remote_only\"")
|
||||
})?;
|
||||
let client = RemoteMcpClient::new_with_config(mcp_server)?;
|
||||
Ok(Arc::new(client) as Arc<dyn Provider>)
|
||||
}
|
||||
McpMode::LocalOnly | McpMode::Legacy => build_local_provider(cfg),
|
||||
McpMode::Disabled => Err(anyhow!(
|
||||
"MCP mode 'disabled' is not supported by the owlen TUI"
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_local_provider(cfg: &Config) -> Result<Arc<dyn Provider>> {
|
||||
let provider_name = cfg.general.default_provider.clone();
|
||||
let provider_cfg = cfg.provider(&provider_name).ok_or_else(|| {
|
||||
anyhow!(format!(
|
||||
"No provider configuration found for '{provider_name}' in [providers]"
|
||||
))
|
||||
})?;
|
||||
|
||||
match provider_cfg.provider_type.as_str() {
|
||||
"ollama" | "ollama_cloud" => {
|
||||
let provider = OllamaProvider::from_config(provider_cfg, Some(&cfg.general))?;
|
||||
Ok(Arc::new(provider) as Arc<dyn Provider>)
|
||||
}
|
||||
other => Err(anyhow!(format!(
|
||||
"Provider type '{other}' is not supported in legacy/local MCP mode"
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
const BASIC_THEME_NAME: &str = "ansi_basic";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum TerminalColorSupport {
|
||||
Full,
|
||||
Limited { term: String },
|
||||
}
|
||||
|
||||
fn detect_terminal_color_support() -> TerminalColorSupport {
|
||||
let term = std::env::var("TERM").unwrap_or_else(|_| "unknown".to_string());
|
||||
let colorterm = std::env::var("COLORTERM").unwrap_or_default();
|
||||
let term_lower = term.to_lowercase();
|
||||
let color_lower = colorterm.to_lowercase();
|
||||
|
||||
let supports_extended = term_lower.contains("256color")
|
||||
|| color_lower.contains("truecolor")
|
||||
|| color_lower.contains("24bit")
|
||||
|| color_lower.contains("fullcolor");
|
||||
|
||||
if supports_extended {
|
||||
TerminalColorSupport::Full
|
||||
} else {
|
||||
TerminalColorSupport::Limited { term }
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_terminal_theme(cfg: &mut Config, support: &TerminalColorSupport) -> Option<String> {
|
||||
match support {
|
||||
TerminalColorSupport::Full => None,
|
||||
TerminalColorSupport::Limited { .. } => {
|
||||
if cfg.ui.theme != BASIC_THEME_NAME {
|
||||
let previous = std::mem::replace(&mut cfg.ui.theme, BASIC_THEME_NAME.to_string());
|
||||
Some(previous)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct OfflineProvider {
|
||||
reason: String,
|
||||
placeholder_model: String,
|
||||
}
|
||||
|
||||
impl OfflineProvider {
|
||||
fn new(reason: String, placeholder_model: String) -> Self {
|
||||
Self {
|
||||
reason,
|
||||
placeholder_model,
|
||||
}
|
||||
}
|
||||
|
||||
fn friendly_response(&self, requested_model: &str) -> ChatResponse {
|
||||
let mut message = String::new();
|
||||
message.push_str("⚠️ Owlen is running in offline mode.\n\n");
|
||||
message.push_str(&self.reason);
|
||||
if !requested_model.is_empty() && requested_model != self.placeholder_model {
|
||||
message.push_str(&format!(
|
||||
"\n\nYou requested model '{}', but no providers are reachable.",
|
||||
requested_model
|
||||
));
|
||||
}
|
||||
message.push_str(
|
||||
"\n\nStart your preferred provider (e.g. `ollama serve`) or switch providers with `:provider` once connectivity is restored.",
|
||||
);
|
||||
|
||||
ChatResponse {
|
||||
message: Message::assistant(message),
|
||||
usage: None,
|
||||
is_streaming: false,
|
||||
is_final: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for OfflineProvider {
|
||||
fn name(&self) -> &str {
|
||||
"offline"
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>, Error> {
|
||||
Ok(vec![ModelInfo {
|
||||
id: self.placeholder_model.clone(),
|
||||
provider: "offline".to_string(),
|
||||
name: format!("Offline (fallback: {})", self.placeholder_model),
|
||||
description: Some("Placeholder model used while no providers are reachable".into()),
|
||||
context_window: None,
|
||||
capabilities: vec![],
|
||||
supports_tools: false,
|
||||
}])
|
||||
}
|
||||
|
||||
async fn send_prompt(&self, request: ChatRequest) -> Result<ChatResponse, Error> {
|
||||
Ok(self.friendly_response(&request.model))
|
||||
}
|
||||
|
||||
async fn stream_prompt(&self, request: ChatRequest) -> Result<ChatStream, Error> {
|
||||
let response = self.friendly_response(&request.model);
|
||||
Ok(Box::pin(stream::iter(vec![Ok(response)])))
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> Result<(), Error> {
|
||||
Err(Error::Provider(anyhow!(
|
||||
"offline provider cannot reach any backing models"
|
||||
)))
|
||||
}
|
||||
|
||||
fn as_any(&self) -> &(dyn std::any::Any + Send + Sync) {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_app(
|
||||
terminal: &mut Terminal<CrosstermBackend<io::Stdout>>,
|
||||
runtime: &mut RuntimeApp,
|
||||
app: &mut ChatApp,
|
||||
session_rx: &mut mpsc::UnboundedReceiver<SessionEvent>,
|
||||
) -> Result<()> {
|
||||
let mut render = |terminal: &mut Terminal<CrosstermBackend<io::Stdout>>,
|
||||
state: &mut ChatApp|
|
||||
-> Result<()> {
|
||||
terminal.draw(|f| ui::render_chat(f, state))?;
|
||||
Ok(())
|
||||
};
|
||||
|
||||
runtime.run(terminal, app, session_rx, &mut render).await?;
|
||||
Ok(())
|
||||
}
|
||||
@@ -1,16 +0,0 @@
|
||||
//! Owlen CLI entrypoint optimised for code-first workflows.
|
||||
#![allow(dead_code, unused_imports)]
|
||||
|
||||
mod bootstrap;
|
||||
mod commands;
|
||||
mod mcp;
|
||||
|
||||
use anyhow::Result;
|
||||
use owlen_core::config as core_config;
|
||||
use owlen_core::mode::Mode;
|
||||
use owlen_tui::config;
|
||||
|
||||
#[tokio::main(flavor = "multi_thread")]
|
||||
async fn main() -> Result<()> {
|
||||
bootstrap::launch(Mode::Code).await
|
||||
}
|
||||
@@ -1,479 +0,0 @@
|
||||
use std::ffi::OsStr;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Context, Result, anyhow, bail};
|
||||
use clap::Subcommand;
|
||||
use owlen_core::LlmProvider;
|
||||
use owlen_core::ProviderConfig;
|
||||
use owlen_core::config::{
|
||||
self as core_config, Config, OLLAMA_CLOUD_API_KEY_ENV, OLLAMA_CLOUD_BASE_URL,
|
||||
OLLAMA_CLOUD_ENDPOINT_KEY, OLLAMA_MODE_KEY,
|
||||
};
|
||||
use owlen_core::credentials::{ApiCredentials, CredentialManager, OLLAMA_CLOUD_CREDENTIAL_ID};
|
||||
use owlen_core::encryption;
|
||||
use owlen_core::providers::OllamaProvider;
|
||||
use owlen_core::storage::StorageManager;
|
||||
use serde_json::Value;
|
||||
|
||||
const DEFAULT_CLOUD_ENDPOINT: &str = OLLAMA_CLOUD_BASE_URL;
|
||||
const CLOUD_ENDPOINT_KEY: &str = OLLAMA_CLOUD_ENDPOINT_KEY;
|
||||
const CLOUD_PROVIDER_KEY: &str = "ollama_cloud";
|
||||
|
||||
#[derive(Debug, Subcommand)]
|
||||
pub enum CloudCommand {
|
||||
/// Configure Ollama Cloud credentials
|
||||
Setup {
|
||||
/// API key passed directly on the command line (prompted when omitted)
|
||||
#[arg(long)]
|
||||
api_key: Option<String>,
|
||||
/// Override the cloud endpoint (default: https://ollama.com)
|
||||
#[arg(long)]
|
||||
endpoint: Option<String>,
|
||||
/// Provider name to configure (default: ollama_cloud)
|
||||
#[arg(long, default_value = "ollama_cloud")]
|
||||
provider: String,
|
||||
/// Overwrite the provider base URL with the cloud endpoint
|
||||
#[arg(long)]
|
||||
force_cloud_base_url: bool,
|
||||
},
|
||||
/// Check connectivity to Ollama Cloud
|
||||
Status {
|
||||
/// Provider name to check (default: ollama_cloud)
|
||||
#[arg(long, default_value = "ollama_cloud")]
|
||||
provider: String,
|
||||
},
|
||||
/// List available cloud-hosted models
|
||||
Models {
|
||||
/// Provider name to query (default: ollama_cloud)
|
||||
#[arg(long, default_value = "ollama_cloud")]
|
||||
provider: String,
|
||||
},
|
||||
/// Remove stored Ollama Cloud credentials
|
||||
Logout {
|
||||
/// Provider name to clear (default: ollama_cloud)
|
||||
#[arg(long, default_value = "ollama_cloud")]
|
||||
provider: String,
|
||||
},
|
||||
}
|
||||
|
||||
pub async fn run_cloud_command(command: CloudCommand) -> Result<()> {
|
||||
match command {
|
||||
CloudCommand::Setup {
|
||||
api_key,
|
||||
endpoint,
|
||||
provider,
|
||||
force_cloud_base_url,
|
||||
} => setup(provider, api_key, endpoint, force_cloud_base_url).await,
|
||||
CloudCommand::Status { provider } => status(provider).await,
|
||||
CloudCommand::Models { provider } => models(provider).await,
|
||||
CloudCommand::Logout { provider } => logout(provider).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn setup(
|
||||
provider: String,
|
||||
api_key: Option<String>,
|
||||
endpoint: Option<String>,
|
||||
force_cloud_base_url: bool,
|
||||
) -> Result<()> {
|
||||
let provider = canonical_provider_name(&provider);
|
||||
let mut config = crate::config::try_load_config().unwrap_or_default();
|
||||
let endpoint =
|
||||
normalize_endpoint(&endpoint.unwrap_or_else(|| DEFAULT_CLOUD_ENDPOINT.to_string()));
|
||||
|
||||
let base_changed = {
|
||||
let entry = ensure_provider_entry(&mut config, &provider);
|
||||
entry.enabled = true;
|
||||
configure_cloud_endpoint(entry, &endpoint, force_cloud_base_url)
|
||||
};
|
||||
|
||||
let key = match api_key {
|
||||
Some(value) if !value.trim().is_empty() => value,
|
||||
_ => {
|
||||
let prompt = format!("Enter API key for {provider}: ");
|
||||
encryption::prompt_password(&prompt)?
|
||||
}
|
||||
};
|
||||
|
||||
if config.privacy.encrypt_local_data {
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
let manager = unlock_credential_manager(&config, storage.clone())?;
|
||||
let credentials = ApiCredentials {
|
||||
api_key: key.clone(),
|
||||
endpoint: endpoint.clone(),
|
||||
};
|
||||
manager
|
||||
.store_credentials(OLLAMA_CLOUD_CREDENTIAL_ID, &credentials)
|
||||
.await?;
|
||||
// Ensure plaintext key is not persisted to disk.
|
||||
if let Some(entry) = config.providers.get_mut(&provider) {
|
||||
entry.api_key = None;
|
||||
}
|
||||
} else if let Some(entry) = config.providers.get_mut(&provider) {
|
||||
entry.api_key = Some(key.clone());
|
||||
}
|
||||
|
||||
crate::config::save_config(&config)?;
|
||||
println!("Saved Ollama configuration for provider '{provider}'.");
|
||||
if config.privacy.encrypt_local_data {
|
||||
println!("API key stored securely in the encrypted credential vault.");
|
||||
} else {
|
||||
println!("API key stored in plaintext configuration (encryption disabled).");
|
||||
}
|
||||
if !force_cloud_base_url && !base_changed {
|
||||
println!(
|
||||
"Local base URL preserved; cloud endpoint stored as {}.",
|
||||
CLOUD_ENDPOINT_KEY
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn status(provider: String) -> Result<()> {
|
||||
let provider = canonical_provider_name(&provider);
|
||||
let mut config = crate::config::try_load_config().unwrap_or_default();
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
let manager = if config.privacy.encrypt_local_data {
|
||||
Some(unlock_credential_manager(&config, storage.clone())?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let api_key = hydrate_api_key(&mut config, manager.as_ref()).await?;
|
||||
{
|
||||
let entry = ensure_provider_entry(&mut config, &provider);
|
||||
entry.enabled = true;
|
||||
configure_cloud_endpoint(entry, DEFAULT_CLOUD_ENDPOINT, false);
|
||||
}
|
||||
|
||||
let provider_cfg = config
|
||||
.provider(&provider)
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("Provider '{provider}' is not configured"))?;
|
||||
|
||||
let endpoint =
|
||||
resolve_cloud_endpoint(&provider_cfg).unwrap_or_else(|| DEFAULT_CLOUD_ENDPOINT.to_string());
|
||||
let mut runtime_cfg = provider_cfg.clone();
|
||||
runtime_cfg.base_url = Some(endpoint.clone());
|
||||
runtime_cfg.extra.insert(
|
||||
OLLAMA_MODE_KEY.to_string(),
|
||||
Value::String("cloud".to_string()),
|
||||
);
|
||||
|
||||
let ollama = OllamaProvider::from_config(&runtime_cfg, Some(&config.general))
|
||||
.with_context(|| "Failed to construct Ollama provider. Run `owlen cloud setup` first.")?;
|
||||
|
||||
match ollama.health_check().await {
|
||||
Ok(_) => {
|
||||
println!("✓ Connected to {provider} ({})", endpoint);
|
||||
if api_key.is_none() && config.privacy.encrypt_local_data {
|
||||
println!(
|
||||
"Warning: No API key stored; connection succeeded via environment variables."
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
println!("✗ Failed to reach {provider}: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn models(provider: String) -> Result<()> {
|
||||
let provider = canonical_provider_name(&provider);
|
||||
let mut config = crate::config::try_load_config().unwrap_or_default();
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
let manager = if config.privacy.encrypt_local_data {
|
||||
Some(unlock_credential_manager(&config, storage.clone())?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
hydrate_api_key(&mut config, manager.as_ref()).await?;
|
||||
|
||||
{
|
||||
let entry = ensure_provider_entry(&mut config, &provider);
|
||||
entry.enabled = true;
|
||||
configure_cloud_endpoint(entry, DEFAULT_CLOUD_ENDPOINT, false);
|
||||
}
|
||||
|
||||
let provider_cfg = config
|
||||
.provider(&provider)
|
||||
.cloned()
|
||||
.ok_or_else(|| anyhow!("Provider '{provider}' is not configured"))?;
|
||||
|
||||
let endpoint =
|
||||
resolve_cloud_endpoint(&provider_cfg).unwrap_or_else(|| DEFAULT_CLOUD_ENDPOINT.to_string());
|
||||
let mut runtime_cfg = provider_cfg.clone();
|
||||
runtime_cfg.base_url = Some(endpoint);
|
||||
runtime_cfg.extra.insert(
|
||||
OLLAMA_MODE_KEY.to_string(),
|
||||
Value::String("cloud".to_string()),
|
||||
);
|
||||
|
||||
let ollama = OllamaProvider::from_config(&runtime_cfg, Some(&config.general))
|
||||
.with_context(|| "Failed to construct Ollama provider. Run `owlen cloud setup` first.")?;
|
||||
|
||||
match ollama.list_models().await {
|
||||
Ok(models) => {
|
||||
if models.is_empty() {
|
||||
println!("No cloud models reported by '{}'.", provider);
|
||||
} else {
|
||||
println!("Models available via '{}':", provider);
|
||||
for model in models {
|
||||
if let Some(description) = &model.description {
|
||||
println!(" - {} ({})", model.id, description);
|
||||
} else {
|
||||
println!(" - {}", model.id);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
bail!("Failed to list models: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn logout(provider: String) -> Result<()> {
|
||||
let provider = canonical_provider_name(&provider);
|
||||
let mut config = crate::config::try_load_config().unwrap_or_default();
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
|
||||
if config.privacy.encrypt_local_data {
|
||||
let manager = unlock_credential_manager(&config, storage.clone())?;
|
||||
manager
|
||||
.delete_credentials(OLLAMA_CLOUD_CREDENTIAL_ID)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if let Some(entry) = config.providers.get_mut(&provider) {
|
||||
entry.api_key = None;
|
||||
entry.enabled = false;
|
||||
}
|
||||
|
||||
crate::config::save_config(&config)?;
|
||||
println!("Cleared credentials for provider '{provider}'.");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn ensure_provider_entry<'a>(config: &'a mut Config, provider: &str) -> &'a mut ProviderConfig {
|
||||
core_config::ensure_provider_config_mut(config, provider)
|
||||
}
|
||||
|
||||
fn configure_cloud_endpoint(entry: &mut ProviderConfig, endpoint: &str, force: bool) -> bool {
|
||||
let normalized = normalize_endpoint(endpoint);
|
||||
let previous_base = entry.base_url.clone();
|
||||
entry.extra.insert(
|
||||
CLOUD_ENDPOINT_KEY.to_string(),
|
||||
Value::String(normalized.clone()),
|
||||
);
|
||||
|
||||
if entry.api_key_env.is_none() {
|
||||
entry.api_key_env = Some(OLLAMA_CLOUD_API_KEY_ENV.to_string());
|
||||
}
|
||||
|
||||
if force
|
||||
|| entry
|
||||
.base_url
|
||||
.as_ref()
|
||||
.map(|value| value.trim().is_empty())
|
||||
.unwrap_or(true)
|
||||
{
|
||||
entry.base_url = Some(normalized.clone());
|
||||
}
|
||||
|
||||
if force {
|
||||
entry.enabled = true;
|
||||
}
|
||||
|
||||
entry.base_url != previous_base
|
||||
}
|
||||
|
||||
fn resolve_cloud_endpoint(cfg: &ProviderConfig) -> Option<String> {
|
||||
if let Some(value) = cfg
|
||||
.extra
|
||||
.get(CLOUD_ENDPOINT_KEY)
|
||||
.and_then(|value| value.as_str())
|
||||
.map(normalize_endpoint)
|
||||
{
|
||||
return Some(value);
|
||||
}
|
||||
|
||||
cfg.base_url
|
||||
.as_ref()
|
||||
.map(|value| value.trim_end_matches('/').to_string())
|
||||
.filter(|value| !value.is_empty())
|
||||
}
|
||||
|
||||
fn normalize_endpoint(endpoint: &str) -> String {
|
||||
let trimmed = endpoint.trim().trim_end_matches('/');
|
||||
if trimmed.is_empty() {
|
||||
DEFAULT_CLOUD_ENDPOINT.to_string()
|
||||
} else {
|
||||
trimmed.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn canonical_provider_name(provider: &str) -> String {
|
||||
let normalized = provider.trim().to_ascii_lowercase().replace('-', "_");
|
||||
match normalized.as_str() {
|
||||
"" => CLOUD_PROVIDER_KEY.to_string(),
|
||||
"ollama" => CLOUD_PROVIDER_KEY.to_string(),
|
||||
"ollama_cloud" => CLOUD_PROVIDER_KEY.to_string(),
|
||||
value => value.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn set_env_var<K, V>(key: K, value: V)
|
||||
where
|
||||
K: AsRef<OsStr>,
|
||||
V: AsRef<OsStr>,
|
||||
{
|
||||
// Safety: the CLI updates process-wide environment variables during startup while no
|
||||
// other threads are mutating the environment.
|
||||
unsafe {
|
||||
std::env::set_var(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
fn set_env_if_missing(var: &str, value: &str) {
|
||||
if std::env::var(var)
|
||||
.map(|v| v.trim().is_empty())
|
||||
.unwrap_or(true)
|
||||
{
|
||||
set_env_var(var, value);
|
||||
}
|
||||
}
|
||||
|
||||
fn unlock_credential_manager(
|
||||
config: &Config,
|
||||
storage: Arc<StorageManager>,
|
||||
) -> Result<Arc<CredentialManager>> {
|
||||
if !config.privacy.encrypt_local_data {
|
||||
bail!("Credential manager requested but encryption is disabled");
|
||||
}
|
||||
|
||||
let secure_path = vault_path(&storage)?;
|
||||
let handle = unlock_vault(&secure_path)?;
|
||||
let master_key = Arc::new(handle.data.master_key.clone());
|
||||
Ok(Arc::new(CredentialManager::new(
|
||||
storage,
|
||||
master_key.clone(),
|
||||
)))
|
||||
}
|
||||
|
||||
fn vault_path(storage: &StorageManager) -> Result<PathBuf> {
|
||||
let base_dir = storage
|
||||
.database_path()
|
||||
.parent()
|
||||
.map(|p| p.to_path_buf())
|
||||
.or_else(dirs::data_local_dir)
|
||||
.unwrap_or_else(|| PathBuf::from("."));
|
||||
Ok(base_dir.join("encrypted_data.json"))
|
||||
}
|
||||
|
||||
fn unlock_vault(path: &Path) -> Result<encryption::VaultHandle> {
|
||||
use std::env;
|
||||
|
||||
if path.exists() {
|
||||
if let Some(password) = env::var("OWLEN_MASTER_PASSWORD")
|
||||
.ok()
|
||||
.map(|value| value.trim().to_string())
|
||||
.filter(|password| !password.is_empty())
|
||||
{
|
||||
return encryption::unlock_with_password(path.to_path_buf(), &password)
|
||||
.context("Failed to unlock vault with OWLEN_MASTER_PASSWORD");
|
||||
}
|
||||
|
||||
for attempt in 0..3 {
|
||||
let password = encryption::prompt_password("Enter master password: ")?;
|
||||
match encryption::unlock_with_password(path.to_path_buf(), &password) {
|
||||
Ok(handle) => {
|
||||
set_env_var("OWLEN_MASTER_PASSWORD", password);
|
||||
return Ok(handle);
|
||||
}
|
||||
Err(err) => {
|
||||
eprintln!("Failed to unlock vault: {err}");
|
||||
if attempt == 2 {
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bail!("Unable to unlock encrypted credential vault");
|
||||
}
|
||||
|
||||
let handle = encryption::unlock_interactive(path.to_path_buf())?;
|
||||
if env::var("OWLEN_MASTER_PASSWORD")
|
||||
.map(|v| v.trim().is_empty())
|
||||
.unwrap_or(true)
|
||||
{
|
||||
let password = encryption::prompt_password("Cache master password for this session: ")?;
|
||||
set_env_var("OWLEN_MASTER_PASSWORD", password);
|
||||
}
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
async fn hydrate_api_key(
|
||||
config: &mut Config,
|
||||
manager: Option<&Arc<CredentialManager>>,
|
||||
) -> Result<Option<String>> {
|
||||
let credentials = match manager {
|
||||
Some(manager) => manager.get_credentials(OLLAMA_CLOUD_CREDENTIAL_ID).await?,
|
||||
None => None,
|
||||
};
|
||||
|
||||
if let Some(credentials) = credentials {
|
||||
let key = credentials.api_key.trim().to_string();
|
||||
if !key.is_empty() {
|
||||
set_env_if_missing("OLLAMA_API_KEY", &key);
|
||||
set_env_if_missing("OLLAMA_CLOUD_API_KEY", &key);
|
||||
}
|
||||
|
||||
let cfg = core_config::ensure_provider_config_mut(config, CLOUD_PROVIDER_KEY);
|
||||
configure_cloud_endpoint(cfg, &credentials.endpoint, false);
|
||||
return Ok(Some(key));
|
||||
}
|
||||
|
||||
if let Some(key) = config
|
||||
.provider(CLOUD_PROVIDER_KEY)
|
||||
.and_then(|cfg| cfg.api_key.as_ref())
|
||||
.map(|value| value.trim())
|
||||
.filter(|value| !value.is_empty())
|
||||
{
|
||||
set_env_if_missing("OLLAMA_API_KEY", key);
|
||||
set_env_if_missing("OLLAMA_CLOUD_API_KEY", key);
|
||||
return Ok(Some(key.to_string()));
|
||||
}
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
pub async fn load_runtime_credentials(
|
||||
config: &mut Config,
|
||||
storage: Arc<StorageManager>,
|
||||
) -> Result<()> {
|
||||
if config.privacy.encrypt_local_data {
|
||||
let manager = unlock_credential_manager(config, storage.clone())?;
|
||||
hydrate_api_key(config, Some(&manager)).await?;
|
||||
} else {
|
||||
hydrate_api_key(config, None).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn canonicalises_provider_names() {
|
||||
assert_eq!(canonical_provider_name("OLLAMA_CLOUD"), CLOUD_PROVIDER_KEY);
|
||||
assert_eq!(canonical_provider_name(" ollama-cloud"), CLOUD_PROVIDER_KEY);
|
||||
assert_eq!(canonical_provider_name(""), CLOUD_PROVIDER_KEY);
|
||||
}
|
||||
}
|
||||
@@ -1,4 +0,0 @@
|
||||
//! Command implementations for the `owlen` CLI.
|
||||
|
||||
pub mod cloud;
|
||||
pub mod providers;
|
||||
@@ -1,651 +0,0 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::{Result, anyhow};
|
||||
use clap::{Args, Subcommand};
|
||||
use owlen_core::ProviderConfig;
|
||||
use owlen_core::config::{self as core_config, Config};
|
||||
use owlen_core::provider::{
|
||||
AnnotatedModelInfo, ModelProvider, ProviderManager, ProviderStatus, ProviderType,
|
||||
};
|
||||
use owlen_core::storage::StorageManager;
|
||||
use owlen_providers::ollama::{OllamaCloudProvider, OllamaLocalProvider};
|
||||
use owlen_tui::config as tui_config;
|
||||
|
||||
use super::cloud;
|
||||
|
||||
/// CLI subcommands for provider management.
|
||||
#[derive(Debug, Subcommand)]
|
||||
pub enum ProvidersCommand {
|
||||
/// List configured providers and their metadata.
|
||||
List,
|
||||
/// Run health checks against providers.
|
||||
Status {
|
||||
/// Optional provider identifier to check.
|
||||
#[arg(value_name = "PROVIDER")]
|
||||
provider: Option<String>,
|
||||
},
|
||||
/// Enable a provider in the configuration.
|
||||
Enable {
|
||||
/// Provider identifier to enable.
|
||||
provider: String,
|
||||
},
|
||||
/// Disable a provider in the configuration.
|
||||
Disable {
|
||||
/// Provider identifier to disable.
|
||||
provider: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// Arguments for the `owlen models` command.
|
||||
#[derive(Debug, Default, Args)]
|
||||
pub struct ModelsArgs {
|
||||
/// Restrict output to a specific provider.
|
||||
#[arg(long)]
|
||||
pub provider: Option<String>,
|
||||
}
|
||||
|
||||
pub async fn run_providers_command(command: ProvidersCommand) -> Result<()> {
|
||||
match command {
|
||||
ProvidersCommand::List => list_providers(),
|
||||
ProvidersCommand::Status { provider } => status_providers(provider.as_deref()).await,
|
||||
ProvidersCommand::Enable { provider } => toggle_provider(&provider, true),
|
||||
ProvidersCommand::Disable { provider } => toggle_provider(&provider, false),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn run_models_command(args: ModelsArgs) -> Result<()> {
|
||||
list_models(args.provider.as_deref()).await
|
||||
}
|
||||
|
||||
fn list_providers() -> Result<()> {
|
||||
let config = tui_config::try_load_config().unwrap_or_default();
|
||||
let default_provider = canonical_provider_id(&config.general.default_provider);
|
||||
|
||||
let mut rows = Vec::new();
|
||||
for (id, cfg) in &config.providers {
|
||||
let type_label = describe_provider_type(id, cfg);
|
||||
let auth_label = describe_auth(cfg, requires_auth(id, cfg));
|
||||
let enabled = if cfg.enabled { "yes" } else { "no" };
|
||||
let default = if id == &default_provider { "*" } else { "" };
|
||||
let base = cfg
|
||||
.base_url
|
||||
.as_ref()
|
||||
.map(|value| value.trim().to_string())
|
||||
.unwrap_or_else(|| "-".to_string());
|
||||
|
||||
rows.push(ProviderListRow {
|
||||
id: id.to_string(),
|
||||
type_label,
|
||||
enabled: enabled.to_string(),
|
||||
default: default.to_string(),
|
||||
auth: auth_label,
|
||||
base_url: base,
|
||||
});
|
||||
}
|
||||
|
||||
rows.sort_by(|a, b| a.id.cmp(&b.id));
|
||||
|
||||
let id_width = rows
|
||||
.iter()
|
||||
.map(|row| row.id.len())
|
||||
.max()
|
||||
.unwrap_or(8)
|
||||
.max("Provider".len());
|
||||
let enabled_width = rows
|
||||
.iter()
|
||||
.map(|row| row.enabled.len())
|
||||
.max()
|
||||
.unwrap_or(7)
|
||||
.max("Enabled".len());
|
||||
let default_width = rows
|
||||
.iter()
|
||||
.map(|row| row.default.len())
|
||||
.max()
|
||||
.unwrap_or(7)
|
||||
.max("Default".len());
|
||||
let type_width = rows
|
||||
.iter()
|
||||
.map(|row| row.type_label.len())
|
||||
.max()
|
||||
.unwrap_or(4)
|
||||
.max("Type".len());
|
||||
let auth_width = rows
|
||||
.iter()
|
||||
.map(|row| row.auth.len())
|
||||
.max()
|
||||
.unwrap_or(4)
|
||||
.max("Auth".len());
|
||||
|
||||
println!(
|
||||
"{:<id_width$} {:<enabled_width$} {:<default_width$} {:<type_width$} {:<auth_width$} Base URL",
|
||||
"Provider",
|
||||
"Enabled",
|
||||
"Default",
|
||||
"Type",
|
||||
"Auth",
|
||||
id_width = id_width,
|
||||
enabled_width = enabled_width,
|
||||
default_width = default_width,
|
||||
type_width = type_width,
|
||||
auth_width = auth_width,
|
||||
);
|
||||
|
||||
for row in rows {
|
||||
println!(
|
||||
"{:<id_width$} {:<enabled_width$} {:<default_width$} {:<type_width$} {:<auth_width$} {}",
|
||||
row.id,
|
||||
row.enabled,
|
||||
row.default,
|
||||
row.type_label,
|
||||
row.auth,
|
||||
row.base_url,
|
||||
id_width = id_width,
|
||||
enabled_width = enabled_width,
|
||||
default_width = default_width,
|
||||
type_width = type_width,
|
||||
auth_width = auth_width,
|
||||
);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn status_providers(filter: Option<&str>) -> Result<()> {
|
||||
let mut config = tui_config::try_load_config().unwrap_or_default();
|
||||
let filter = filter.map(canonical_provider_id);
|
||||
verify_provider_filter(&config, filter.as_deref())?;
|
||||
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
cloud::load_runtime_credentials(&mut config, storage.clone()).await?;
|
||||
|
||||
let manager = ProviderManager::new(&config);
|
||||
let records = register_enabled_providers(&manager, &config, filter.as_deref()).await?;
|
||||
let health = manager.refresh_health().await;
|
||||
|
||||
let mut rows = Vec::new();
|
||||
for record in records {
|
||||
let status = health.get(&record.id).copied();
|
||||
rows.push(ProviderStatusRow::from_record(record, status));
|
||||
}
|
||||
|
||||
rows.sort_by(|a, b| a.id.cmp(&b.id));
|
||||
print_status_rows(&rows);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn list_models(filter: Option<&str>) -> Result<()> {
|
||||
let mut config = tui_config::try_load_config().unwrap_or_default();
|
||||
let filter = filter.map(canonical_provider_id);
|
||||
verify_provider_filter(&config, filter.as_deref())?;
|
||||
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
cloud::load_runtime_credentials(&mut config, storage.clone()).await?;
|
||||
|
||||
let manager = ProviderManager::new(&config);
|
||||
let records = register_enabled_providers(&manager, &config, filter.as_deref()).await?;
|
||||
let models = manager
|
||||
.list_all_models()
|
||||
.await
|
||||
.map_err(|err| anyhow!(err))?;
|
||||
let statuses = manager.provider_statuses().await;
|
||||
|
||||
print_models(records, models, statuses);
|
||||
Ok(())
|
||||
}
|
||||
|
fn verify_provider_filter(config: &Config, filter: Option<&str>) -> Result<()> {
    if let Some(filter) = filter
        && !config.providers.contains_key(filter)
    {
        return Err(anyhow!(
            "Provider '{}' is not defined in configuration.",
            filter
        ));
    }
    Ok(())
}

fn toggle_provider(provider: &str, enable: bool) -> Result<()> {
    let mut config = tui_config::try_load_config().unwrap_or_default();
    let canonical = canonical_provider_id(provider);
    if canonical.is_empty() {
        return Err(anyhow!("Provider name cannot be empty."));
    }

    let previous_default = config.general.default_provider.clone();
    let previous_fallback_enabled = config.providers.get("ollama_local").map(|cfg| cfg.enabled);

    let previous_enabled;
    {
        let entry = core_config::ensure_provider_config_mut(&mut config, &canonical);
        previous_enabled = entry.enabled;
        if previous_enabled == enable {
            println!(
                "Provider '{}' is already {}.",
                canonical,
                if enable { "enabled" } else { "disabled" }
            );
            return Ok(());
        }
        entry.enabled = enable;
    }

    if !enable && config.general.default_provider == canonical {
        if let Some(candidate) = choose_fallback_provider(&config, &canonical) {
            config.general.default_provider = candidate.clone();
            println!(
                "Default provider set to '{}' because '{}' was disabled.",
                candidate, canonical
            );
        } else {
            let entry = core_config::ensure_provider_config_mut(&mut config, "ollama_local");
            entry.enabled = true;
            config.general.default_provider = "ollama_local".to_string();
            println!(
                "Enabled 'ollama_local' and made it default because no other providers are active."
            );
        }
    }

    if let Err(err) = config.validate() {
        {
            let entry = core_config::ensure_provider_config_mut(&mut config, &canonical);
            entry.enabled = previous_enabled;
        }
        config.general.default_provider = previous_default;
        if let Some(enabled) = previous_fallback_enabled
            && let Some(entry) = config.providers.get_mut("ollama_local")
        {
            entry.enabled = enabled;
        }
        return Err(anyhow!(err));
    }

    tui_config::save_config(&config).map_err(|err| anyhow!(err))?;

    println!(
        "{} provider '{}'.",
        if enable { "Enabled" } else { "Disabled" },
        canonical
    );
    Ok(())
}

fn choose_fallback_provider(config: &Config, exclude: &str) -> Option<String> {
    if exclude != "ollama_local"
        && let Some(cfg) = config.providers.get("ollama_local")
        && cfg.enabled
    {
        return Some("ollama_local".to_string());
    }

    let mut candidates: Vec<String> = config
        .providers
        .iter()
        .filter(|(id, cfg)| cfg.enabled && id.as_str() != exclude)
        .map(|(id, _)| id.clone())
        .collect();
    candidates.sort();
    candidates.into_iter().next()
}

async fn register_enabled_providers(
    manager: &ProviderManager,
    config: &Config,
    filter: Option<&str>,
) -> Result<Vec<ProviderRecord>> {
    let default_provider = canonical_provider_id(&config.general.default_provider);
    let mut records = Vec::new();

    for (id, cfg) in &config.providers {
        if let Some(filter) = filter
            && id != filter
        {
            continue;
        }

        let mut record = ProviderRecord::from_config(id, cfg, id == &default_provider);
        if !cfg.enabled {
            records.push(record);
            continue;
        }

        match instantiate_provider(id, cfg) {
            Ok(provider) => {
                let metadata = provider.metadata().clone();
                record.provider_type_label = provider_type_label(metadata.provider_type);
                record.requires_auth = metadata.requires_auth;
                record.metadata = Some(metadata);
                manager.register_provider(provider).await;
            }
            Err(err) => {
                record.registration_error = Some(err.to_string());
            }
        }

        records.push(record);
    }

    records.sort_by(|a, b| a.id.cmp(&b.id));
    Ok(records)
}

fn instantiate_provider(id: &str, cfg: &ProviderConfig) -> Result<Arc<dyn ModelProvider>> {
    let kind = cfg.provider_type.trim().to_ascii_lowercase();
    if kind == "ollama" || id == "ollama_local" {
        let provider = OllamaLocalProvider::new(cfg.base_url.clone(), None, None)
            .map_err(|err| anyhow!(err))?;
        Ok(Arc::new(provider))
    } else if kind == "ollama_cloud" || id == "ollama_cloud" {
        let provider = OllamaCloudProvider::new(cfg.base_url.clone(), cfg.api_key.clone(), None)
            .map_err(|err| anyhow!(err))?;
        Ok(Arc::new(provider))
    } else {
        Err(anyhow!(
            "Provider '{}' uses unsupported type '{}'.",
            id,
            if kind.is_empty() {
                "unknown"
            } else {
                kind.as_str()
            }
        ))
    }
}

fn describe_provider_type(id: &str, cfg: &ProviderConfig) -> String {
    if cfg.provider_type.trim().eq_ignore_ascii_case("ollama") || id.ends_with("_local") {
        "Local".to_string()
    } else if cfg
        .provider_type
        .trim()
        .eq_ignore_ascii_case("ollama_cloud")
        || id.contains("cloud")
    {
        "Cloud".to_string()
    } else {
        "Custom".to_string()
    }
}

fn requires_auth(id: &str, cfg: &ProviderConfig) -> bool {
    cfg.api_key.is_some()
        || cfg.api_key_env.is_some()
        || matches!(id, "ollama_cloud" | "openai" | "anthropic")
}

fn describe_auth(cfg: &ProviderConfig, required: bool) -> String {
    if let Some(env) = cfg
        .api_key_env
        .as_ref()
        .map(|value| value.trim())
        .filter(|value| !value.is_empty())
    {
        format!("env:{env}")
    } else if cfg
        .api_key
        .as_ref()
        .map(|value| !value.trim().is_empty())
        .unwrap_or(false)
    {
        "config".to_string()
    } else if required {
        "required".to_string()
    } else {
        "-".to_string()
    }
}

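For quick reference, a hedged summary of the auth labels describe_auth produces; the environment-variable name below is only an example:

// api_key_env = Some("OLLAMA_API_KEY")            -> "env:OLLAMA_API_KEY"
// api_key set to a non-empty value                -> "config"
// neither set, but the provider requires auth     -> "required"
// neither set and no auth required                -> "-"
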
fn canonical_provider_id(raw: &str) -> String {
    let trimmed = raw.trim().to_ascii_lowercase();
    if trimmed.is_empty() {
        return trimmed;
    }

    match trimmed.as_str() {
        "ollama" | "ollama-local" => "ollama_local".to_string(),
        "ollama_cloud" | "ollama-cloud" => "ollama_cloud".to_string(),
        other => other.replace('-', "_"),
    }
}

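A minimal test sketch (not part of the original file) showing how the alias normalisation above behaves:

#[cfg(test)]
mod canonical_provider_id_examples {
    use super::canonical_provider_id;

    #[test]
    fn normalises_known_aliases() {
        // Known aliases collapse onto the canonical ids used elsewhere in this module.
        assert_eq!(canonical_provider_id("Ollama"), "ollama_local");
        assert_eq!(canonical_provider_id("ollama-cloud"), "ollama_cloud");
        // Other ids are lowercased and hyphens become underscores.
        assert_eq!(canonical_provider_id("My-Provider"), "my_provider");
        // Whitespace-only input stays empty so callers can reject it.
        assert_eq!(canonical_provider_id("   "), "");
    }
}
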
fn provider_type_label(provider_type: ProviderType) -> String {
    match provider_type {
        ProviderType::Local => "Local".to_string(),
        ProviderType::Cloud => "Cloud".to_string(),
    }
}

fn provider_status_strings(status: ProviderStatus) -> (&'static str, &'static str) {
    match status {
        ProviderStatus::Available => ("OK", "available"),
        ProviderStatus::Unavailable => ("ERR", "unavailable"),
        ProviderStatus::RequiresSetup => ("SETUP", "requires setup"),
    }
}

fn print_status_rows(rows: &[ProviderStatusRow]) {
    let id_width = rows
        .iter()
        .map(|row| row.id.len())
        .max()
        .unwrap_or(8)
        .max("Provider".len());
    let type_width = rows
        .iter()
        .map(|row| row.provider_type.len())
        .max()
        .unwrap_or(4)
        .max("Type".len());
    let status_width = rows
        .iter()
        .map(|row| row.indicator.len() + 1 + row.status_label.len())
        .max()
        .unwrap_or(6)
        .max("State".len());

    println!(
        "{:<id_width$} {:<4} {:<type_width$} {:<status_width$} Details",
        "Provider",
        "Def",
        "Type",
        "State",
        id_width = id_width,
        type_width = type_width,
        status_width = status_width,
    );

    for row in rows {
        let def = if row.default_provider { "*" } else { "-" };
        let details = row.detail.as_deref().unwrap_or("-");
        println!(
            "{:<id_width$} {:<4} {:<type_width$} {:<status_width$} {}",
            row.id,
            def,
            row.provider_type,
            format!("{} {}", row.indicator, row.status_label),
            details,
            id_width = id_width,
            type_width = type_width,
            status_width = status_width,
        );
    }
}

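A hedged illustration of the table print_status_rows emits (provider states and details are made up):

// Provider     Def  Type  State                Details
// ollama_cloud -    Cloud SETUP requires setup credentials required
// ollama_local *    Local OK available         -
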
fn print_models(
    records: Vec<ProviderRecord>,
    models: Vec<AnnotatedModelInfo>,
    statuses: HashMap<String, ProviderStatus>,
) {
    let mut grouped: HashMap<String, Vec<AnnotatedModelInfo>> = HashMap::new();
    for info in models {
        grouped
            .entry(info.provider_id.clone())
            .or_default()
            .push(info);
    }

    for record in records {
        let status = statuses.get(&record.id).copied().or_else(|| {
            if record.metadata.is_some() && record.registration_error.is_none() && record.enabled {
                Some(ProviderStatus::Unavailable)
            } else {
                None
            }
        });

        let (indicator, label, status_value) = if !record.enabled {
            ("-", "disabled", None)
        } else if record.registration_error.is_some() {
            ("ERR", "error", None)
        } else if let Some(status) = status {
            let (indicator, label) = provider_status_strings(status);
            (indicator, label, Some(status))
        } else {
            ("?", "unknown", None)
        };

        let title = if record.default_provider {
            format!("{} (default)", record.id)
        } else {
            record.id.clone()
        };
        println!(
            "{} {} [{}] {}",
            indicator, title, record.provider_type_label, label
        );

        if let Some(err) = &record.registration_error {
            println!(" error: {}", err);
            println!();
            continue;
        }

        if !record.enabled {
            println!(" provider disabled");
            println!();
            continue;
        }

        if let Some(entries) = grouped.get(&record.id) {
            let mut entries = entries.clone();
            entries.sort_by(|a, b| a.model.name.cmp(&b.model.name));
            if entries.is_empty() {
                println!(" (no models reported)");
            } else {
                for entry in entries {
                    let mut line = format!(" - {}", entry.model.name);
                    if let Some(description) = &entry.model.description
                        && !description.trim().is_empty()
                    {
                        line.push_str(&format!(" — {}", description.trim()));
                    }
                    println!("{}", line);
                }
            }
        } else {
            println!(" (no models reported)");
        }

        if let Some(ProviderStatus::RequiresSetup) = status_value
            && record.requires_auth
        {
            println!(" configure provider credentials or API key");
        }
        println!();
    }
}

struct ProviderListRow {
    id: String,
    type_label: String,
    enabled: String,
    default: String,
    auth: String,
    base_url: String,
}

struct ProviderRecord {
    id: String,
    enabled: bool,
    default_provider: bool,
    provider_type_label: String,
    requires_auth: bool,
    registration_error: Option<String>,
    metadata: Option<owlen_core::provider::ProviderMetadata>,
}

impl ProviderRecord {
    fn from_config(id: &str, cfg: &ProviderConfig, default_provider: bool) -> Self {
        Self {
            id: id.to_string(),
            enabled: cfg.enabled,
            default_provider,
            provider_type_label: describe_provider_type(id, cfg),
            requires_auth: requires_auth(id, cfg),
            registration_error: None,
            metadata: None,
        }
    }
}

struct ProviderStatusRow {
    id: String,
    provider_type: String,
    default_provider: bool,
    indicator: String,
    status_label: String,
    detail: Option<String>,
}

impl ProviderStatusRow {
    fn from_record(record: ProviderRecord, status: Option<ProviderStatus>) -> Self {
        if !record.enabled {
            return Self {
                id: record.id,
                provider_type: record.provider_type_label,
                default_provider: record.default_provider,
                indicator: "-".to_string(),
                status_label: "disabled".to_string(),
                detail: None,
            };
        }

        if let Some(err) = record.registration_error {
            return Self {
                id: record.id,
                provider_type: record.provider_type_label,
                default_provider: record.default_provider,
                indicator: "ERR".to_string(),
                status_label: "error".to_string(),
                detail: Some(err),
            };
        }

        if let Some(status) = status {
            let (indicator, label) = provider_status_strings(status);
            return Self {
                id: record.id,
                provider_type: record.provider_type_label,
                default_provider: record.default_provider,
                indicator: indicator.to_string(),
                status_label: label.to_string(),
                detail: if matches!(status, ProviderStatus::RequiresSetup) && record.requires_auth {
                    Some("credentials required".to_string())
                } else {
                    None
                },
            };
        }

        Self {
            id: record.id,
            provider_type: record.provider_type_label,
            default_provider: record.default_provider,
            indicator: "?".to_string(),
            status_label: "unknown".to_string(),
            detail: None,
        }
    }
}

@@ -1,8 +0,0 @@
//! Library portion of the `owlen-cli` crate.
//!
//! It currently only re‑exports the `agent` module used by the standalone
//! `owlen-agent` binary. Additional shared functionality can be added here in
//! the future.

// Re-export agent module from owlen-core
pub use owlen_core::agent;

@@ -1,228 +0,0 @@
#![allow(clippy::collapsible_if)] // TODO: Remove once Rust 2024 let-chains are available

//! OWLEN CLI - Chat TUI client

mod bootstrap;
mod commands;
mod mcp;

use anyhow::Result;
use clap::{Parser, Subcommand};
use commands::{
    cloud::{CloudCommand, run_cloud_command},
    providers::{ModelsArgs, ProvidersCommand, run_models_command, run_providers_command},
};
use mcp::{McpCommand, run_mcp_command};
use owlen_core::config as core_config;
use owlen_core::config::McpMode;
use owlen_core::mode::Mode;
use owlen_tui::config;

/// Owlen - Terminal UI for LLM chat
#[derive(Parser, Debug)]
#[command(name = "owlen")]
#[command(about = "Terminal UI for LLM chat via MCP", long_about = None)]
struct Args {
    /// Start in code mode (enables all tools)
    #[arg(long, short = 'c')]
    code: bool,
    #[command(subcommand)]
    command: Option<OwlenCommand>,
}

#[derive(Debug, Subcommand)]
enum OwlenCommand {
    /// Inspect or upgrade configuration files
    #[command(subcommand)]
    Config(ConfigCommand),
    /// Manage Ollama Cloud credentials
    #[command(subcommand)]
    Cloud(CloudCommand),
    /// Manage model providers
    #[command(subcommand)]
    Providers(ProvidersCommand),
    /// List models exposed by configured providers
    Models(ModelsArgs),
    /// Manage MCP server registrations
    #[command(subcommand)]
    Mcp(McpCommand),
    /// Show manual steps for updating Owlen to the latest revision
    Upgrade,
}

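For orientation, a hedged sketch of how these subcommands are invoked from a shell:

// owlen                  # launch the chat TUI
// owlen --code           # start in code mode (all tools enabled)
// owlen config path      # print the resolved configuration file path
// owlen config doctor    # upgrade legacy configuration values
// owlen mcp list         # list registered MCP servers
// owlen upgrade          # show manual update steps
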
#[derive(Debug, Subcommand)]
enum ConfigCommand {
    /// Automatically upgrade legacy configuration values and ensure validity
    Doctor,
    /// Print the resolved configuration file path
    Path,
}

async fn run_command(command: OwlenCommand) -> Result<()> {
    match command {
        OwlenCommand::Config(config_cmd) => run_config_command(config_cmd),
        OwlenCommand::Cloud(cloud_cmd) => run_cloud_command(cloud_cmd).await,
        OwlenCommand::Providers(provider_cmd) => run_providers_command(provider_cmd).await,
        OwlenCommand::Models(args) => run_models_command(args).await,
        OwlenCommand::Mcp(mcp_cmd) => run_mcp_command(mcp_cmd),
        OwlenCommand::Upgrade => {
            println!(
                "To update Owlen from source:\n git pull\n cargo install --path crates/owlen-cli --force"
            );
            println!(
                "If you installed from the AUR, use your package manager (e.g., yay -S owlen-git)."
            );
            Ok(())
        }
    }
}

fn run_config_command(command: ConfigCommand) -> Result<()> {
    match command {
        ConfigCommand::Doctor => run_config_doctor(),
        ConfigCommand::Path => {
            let path = core_config::default_config_path();
            println!("{}", path.display());
            Ok(())
        }
    }
}

fn run_config_doctor() -> Result<()> {
    let config_path = core_config::default_config_path();
    let existed = config_path.exists();
    let mut config = config::try_load_config().unwrap_or_default();
    let _ = config.refresh_mcp_servers(None);
    let mut changes = Vec::new();

    if !existed {
        changes.push("created configuration file from defaults".to_string());
    }

    if config.provider(&config.general.default_provider).is_none() {
        config.general.default_provider = "ollama_local".to_string();
        changes.push("default provider missing; reset to 'ollama_local'".to_string());
    }

    for key in ["ollama_local", "ollama_cloud", "openai", "anthropic"] {
        if !config.providers.contains_key(key) {
            core_config::ensure_provider_config_mut(&mut config, key);
            changes.push(format!("added default configuration for provider '{key}'"));
        }
    }

    if let Some(entry) = config.providers.get_mut("ollama_local") {
        if entry.provider_type.trim().is_empty() || entry.provider_type != "ollama" {
            entry.provider_type = "ollama".to_string();
            changes.push("normalised providers.ollama_local.provider_type to 'ollama'".to_string());
        }
    }

    let mut ensure_default_enabled = true;

    if !config.providers.values().any(|cfg| cfg.enabled) {
        let entry = core_config::ensure_provider_config_mut(&mut config, "ollama_local");
        if !entry.enabled {
            entry.enabled = true;
            changes.push("no providers were enabled; enabled 'ollama_local'".to_string());
        }
        if config.general.default_provider != "ollama_local" {
            config.general.default_provider = "ollama_local".to_string();
            changes.push(
                "default provider reset to 'ollama_local' because no providers were enabled"
                    .to_string(),
            );
        }
        ensure_default_enabled = false;
    }

    if ensure_default_enabled {
        let default_id = config.general.default_provider.clone();
        if let Some(default_cfg) = config.providers.get(&default_id) {
            if !default_cfg.enabled {
                if let Some(new_default) = config
                    .providers
                    .iter()
                    .filter(|(id, cfg)| cfg.enabled && *id != &default_id)
                    .map(|(id, _)| id.clone())
                    .min()
                {
                    config.general.default_provider = new_default.clone();
                    changes.push(format!(
                        "default provider '{default_id}' was disabled; switched default to '{new_default}'"
                    ));
                } else {
                    let entry =
                        core_config::ensure_provider_config_mut(&mut config, "ollama_local");
                    if !entry.enabled {
                        entry.enabled = true;
                        changes.push(
                            "enabled 'ollama_local' because default provider was disabled"
                                .to_string(),
                        );
                    }
                    if config.general.default_provider != "ollama_local" {
                        config.general.default_provider = "ollama_local".to_string();
                        changes.push(
                            "default provider reset to 'ollama_local' because previous default was disabled"
                                .to_string(),
                        );
                    }
                }
            }
        }
    }

    match config.mcp.mode {
        McpMode::Legacy => {
            config.mcp.mode = McpMode::LocalOnly;
            config.mcp.warn_on_legacy = true;
            changes.push("converted [mcp].mode = 'legacy' to 'local_only'".to_string());
        }
        McpMode::RemoteOnly if config.effective_mcp_servers().is_empty() => {
            config.mcp.mode = McpMode::RemotePreferred;
            config.mcp.allow_fallback = true;
            changes.push(
                "downgraded remote-only configuration to remote_preferred because no servers are defined"
                    .to_string(),
            );
        }
        McpMode::RemotePreferred
            if !config.mcp.allow_fallback && config.effective_mcp_servers().is_empty() =>
        {
            config.mcp.allow_fallback = true;
            changes.push(
                "enabled [mcp].allow_fallback because no remote servers are configured".to_string(),
            );
        }
        _ => {}
    }

    config.validate()?;
    config::save_config(&config)?;

    if changes.is_empty() {
        println!(
            "Configuration already up to date: {}",
            config_path.display()
        );
    } else {
        println!("Updated {}:", config_path.display());
        for change in changes {
            println!(" - {change}");
        }
    }

    Ok(())
}

#[tokio::main(flavor = "multi_thread")]
async fn main() -> Result<()> {
    // Parse command-line arguments
    let Args { code, command } = Args::parse();
    if let Some(command) = command {
        return run_command(command).await;
    }
    let initial_mode = if code { Mode::Code } else { Mode::Chat };
    bootstrap::launch(initial_mode).await
}

@@ -1,259 +0,0 @@
use std::collections::{HashMap, HashSet};

use anyhow::{Result, anyhow};
use clap::{Args, Subcommand, ValueEnum};
use owlen_core::config::{self as core_config, Config, McpConfigScope, McpServerConfig};
use owlen_tui::config as tui_config;

#[derive(Debug, Subcommand)]
pub enum McpCommand {
    /// Add or update an MCP server in the selected scope
    Add(AddArgs),
    /// List MCP servers across scopes
    List(ListArgs),
    /// Remove an MCP server from a scope
    Remove(RemoveArgs),
}

pub fn run_mcp_command(command: McpCommand) -> Result<()> {
    match command {
        McpCommand::Add(args) => handle_add(args),
        McpCommand::List(args) => handle_list(args),
        McpCommand::Remove(args) => handle_remove(args),
    }
}

#[derive(Debug, Clone, Copy, ValueEnum, Default)]
pub enum ScopeArg {
    User,
    #[default]
    Project,
    Local,
}

impl From<ScopeArg> for McpConfigScope {
    fn from(value: ScopeArg) -> Self {
        match value {
            ScopeArg::User => McpConfigScope::User,
            ScopeArg::Project => McpConfigScope::Project,
            ScopeArg::Local => McpConfigScope::Local,
        }
    }
}

#[derive(Debug, Args)]
pub struct AddArgs {
    /// Logical name used to reference the server
    pub name: String,
    /// Command or endpoint invoked for the server
    pub command: String,
    /// Transport mechanism (stdio, http, websocket)
    #[arg(long, default_value = "stdio")]
    pub transport: String,
    /// Configuration scope to write the server into
    #[arg(long, value_enum, default_value_t = ScopeArg::Project)]
    pub scope: ScopeArg,
    /// Environment variables (KEY=VALUE) passed to the server process
    #[arg(long = "env")]
    pub env: Vec<String>,
    /// Additional arguments appended when launching the server
    #[arg(trailing_var_arg = true, value_name = "ARG")]
    pub args: Vec<String>,
}

#[derive(Debug, Args, Default)]
pub struct ListArgs {
    /// Restrict output to a specific configuration scope
    #[arg(long, value_enum)]
    pub scope: Option<ScopeArg>,
    /// Display only the effective servers (after precedence resolution)
    #[arg(long)]
    pub effective_only: bool,
}

#[derive(Debug, Args)]
pub struct RemoveArgs {
    /// Name of the server to remove
    pub name: String,
    /// Optional explicit scope to remove from
    #[arg(long, value_enum)]
    pub scope: Option<ScopeArg>,
}

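A hedged sketch of the resulting command-line surface; the server name, command, and token value are hypothetical:

// owlen mcp add docs ./docs-server --transport stdio --scope user --env DOCS_TOKEN=secret
// owlen mcp list --effective-only
// owlen mcp remove docs --scope user
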
fn handle_add(args: AddArgs) -> Result<()> {
    let mut config = load_config()?;
    let scope: McpConfigScope = args.scope.into();
    let mut env_map = HashMap::new();
    for pair in &args.env {
        let (key, value) = pair
            .split_once('=')
            .ok_or_else(|| anyhow!("Environment pairs must use KEY=VALUE syntax: '{}'", pair))?;
        if key.trim().is_empty() {
            return Err(anyhow!("Environment variable name cannot be empty"));
        }
        env_map.insert(key.trim().to_string(), value.to_string());
    }

    let server = McpServerConfig {
        name: args.name.clone(),
        command: args.command.clone(),
        args: args.args.clone(),
        transport: args.transport.to_lowercase(),
        env: env_map,
        oauth: None,
    };

    config.add_mcp_server(scope, server.clone(), None)?;
    if matches!(scope, McpConfigScope::User) {
        tui_config::save_config(&config)?;
    }

    if let Some(path) = core_config::mcp_scope_path(scope, None) {
        println!(
            "Registered MCP server '{}' in {} scope ({})",
            server.name,
            scope,
            path.display()
        );
    } else {
        println!(
            "Registered MCP server '{}' in {} scope.",
            server.name, scope
        );
    }

    Ok(())
}

fn handle_list(args: ListArgs) -> Result<()> {
    let mut config = load_config()?;
    config.refresh_mcp_servers(None)?;

    let scoped = config.scoped_mcp_servers();
    if scoped.is_empty() {
        println!("No MCP servers configured.");
        return Ok(());
    }

    let filter_scope = args.scope.map(|scope| scope.into());
    let effective = config.effective_mcp_servers();
    let mut active = HashSet::new();
    for server in effective {
        active.insert((
            server.name.clone(),
            server.command.clone(),
            server.transport.to_lowercase(),
        ));
    }

    println!(
        "{:<2} {:<8} {:<20} {:<10} Command",
        "", "Scope", "Name", "Transport"
    );
    for entry in scoped {
        if filter_scope
            .as_ref()
            .is_some_and(|target_scope| entry.scope != *target_scope)
        {
            continue;
        }

        let payload = format_command_line(&entry.config.command, &entry.config.args);
        let key = (
            entry.config.name.clone(),
            entry.config.command.clone(),
            entry.config.transport.to_lowercase(),
        );
        let marker = if active.contains(&key) { "*" } else { " " };

        if args.effective_only && marker != "*" {
            continue;
        }

        println!(
            "{} {:<8} {:<20} {:<10} {}",
            marker, entry.scope, entry.config.name, entry.config.transport, payload
        );
    }

    let scoped_resources = config.scoped_mcp_resources();
    if !scoped_resources.is_empty() {
        println!();
        println!("{:<2} {:<8} {:<30} Title", "", "Scope", "Resource");
        let effective_keys: HashSet<(String, String)> = config
            .effective_mcp_resources()
            .iter()
            .map(|res| (res.server.clone(), res.uri.clone()))
            .collect();

        for entry in scoped_resources {
            if filter_scope
                .as_ref()
                .is_some_and(|target_scope| entry.scope != *target_scope)
            {
                continue;
            }

            let key = (entry.config.server.clone(), entry.config.uri.clone());
            let marker = if effective_keys.contains(&key) {
                "*"
            } else {
                " "
            };
            if args.effective_only && marker != "*" {
                continue;
            }

            let reference = format!("@{}:{}", entry.config.server, entry.config.uri);
            let title = entry.config.title.as_deref().unwrap_or("—");

            println!("{} {:<8} {:<30} {}", marker, entry.scope, reference, title);
        }
    }

    Ok(())
}

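For reference, a hedged example of what handle_list might print; '*' marks entries that survive precedence resolution, and the names and commands are made up:

//   Scope    Name                 Transport  Command
// * project  docs                 stdio      ./docs-server --index ./docs
//   user     search               http       https://example.com/mcp
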
fn handle_remove(args: RemoveArgs) -> Result<()> {
    let mut config = load_config()?;
    let scope_hint = args.scope.map(|scope| scope.into());
    let result = config.remove_mcp_server(scope_hint, &args.name, None)?;

    match result {
        Some(scope) => {
            if matches!(scope, McpConfigScope::User) {
                tui_config::save_config(&config)?;
            }

            if let Some(path) = core_config::mcp_scope_path(scope, None) {
                println!(
                    "Removed MCP server '{}' from {} scope ({})",
                    args.name,
                    scope,
                    path.display()
                );
            } else {
                println!("Removed MCP server '{}' from {} scope.", args.name, scope);
            }
        }
        None => {
            println!("No MCP server named '{}' was found.", args.name);
        }
    }

    Ok(())
}

fn load_config() -> Result<Config> {
    let mut config = tui_config::try_load_config().unwrap_or_default();
    config.refresh_mcp_servers(None)?;
    Ok(config)
}

fn format_command_line(command: &str, args: &[String]) -> String {
    if args.is_empty() {
        command.to_string()
    } else {
        format!("{} {}", command, args.join(" "))
    }
}