Files
owlen/crates/llm/core/src/lib.rs
vikingowl 10c8e2baae feat(v2): complete multi-LLM providers, TUI redesign, and advanced agent features
Multi-LLM Provider Support:
- Add llm-core crate with LlmProvider trait abstraction
- Implement Anthropic Claude API client with streaming
- Implement OpenAI API client with streaming
- Add token counting with SimpleTokenCounter and ClaudeTokenCounter
- Add retry logic with exponential backoff and jitter

Borderless TUI Redesign:
- Rewrite theme system with terminal capability detection (Full/Unicode256/Basic)
- Add provider tabs component with keybind switching [1]/[2]/[3]
- Implement vim-modal input (Normal/Insert/Visual/Command modes)
- Redesign chat panel with timestamps and streaming indicators
- Add multi-provider status bar with cost tracking
- Add Nerd Font icons with graceful ASCII fallbacks
- Add syntax highlighting (syntect) and markdown rendering (pulldown-cmark)

Advanced Agent Features:
- Add system prompt builder with configurable components
- Enhance subagent orchestration with parallel execution
- Add git integration module for safe command detection
- Add streaming tool results via channels
- Expand tool set: AskUserQuestion, TodoWrite, LS, MultiEdit, BashOutput, KillShell
- Add WebSearch with provider abstraction

Plugin System Enhancement:
- Add full agent definition parsing from YAML frontmatter
- Add skill system with progressive disclosure
- Wire plugin hooks into HookManager

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-02 17:24:14 +01:00

797 lines
22 KiB
Rust

//! LLM Provider Abstraction Layer
//!
//! This crate defines the common types and traits for LLM provider integration.
//! Providers (Ollama, Anthropic Claude, OpenAI) implement the `LlmProvider` trait
//! to enable swapping providers at runtime.
use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::pin::Pin;
use thiserror::Error;
// ============================================================================
// Public Modules
// ============================================================================
pub mod retry;
pub mod tokens;
// Re-export token counting types for convenience
pub use tokens::{ClaudeTokenCounter, ContextWindow, SimpleTokenCounter, TokenCounter};
// Re-export retry types for convenience
pub use retry::{is_retryable_error, RetryConfig, RetryStrategy};
// ============================================================================
// Error Types
// ============================================================================
/// Unified error type returned by all LLM provider operations.
#[derive(Error, Debug)]
pub enum LlmError {
    /// Transport-level HTTP failure (connection refused, bad status, etc.).
    #[error("HTTP error: {0}")]
    Http(String),
    /// Response body could not be parsed as the expected JSON.
    #[error("JSON parsing error: {0}")]
    Json(String),
    /// Missing, invalid, or rejected credentials.
    #[error("Authentication error: {0}")]
    Auth(String),
    /// Provider rate limit hit; `retry_after_secs` is the server-suggested
    /// wait in seconds, when the provider supplied one.
    #[error("Rate limit exceeded: retry after {retry_after_secs:?} seconds")]
    RateLimit { retry_after_secs: Option<u64> },
    /// Error reported by the provider's API, with an optional
    /// provider-specific error code.
    #[error("API error: {message}")]
    Api { message: String, code: Option<String> },
    /// Provider-specific failure not covered by the other variants.
    #[error("Provider error: {0}")]
    Provider(String),
    /// Failure while consuming a streaming response.
    #[error("Stream error: {0}")]
    Stream(String),
    /// Request exceeded its time budget.
    #[error("Request timeout: {0}")]
    Timeout(String),
}
// ============================================================================
// Message Types
// ============================================================================
/// Role of a message in the conversation
///
/// Serialized in lowercase ("system", "user", "assistant", "tool") to match
/// common chat-completion wire formats.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    /// System / instruction message.
    System,
    /// End-user message.
    User,
    /// Model-generated message.
    Assistant,
    /// Result of a tool invocation.
    Tool,
}
impl Role {
pub fn as_str(&self) -> &'static str {
match self {
Role::System => "system",
Role::User => "user",
Role::Assistant => "assistant",
Role::Tool => "tool",
}
}
}
impl From<&str> for Role {
fn from(s: &str) -> Self {
match s.to_lowercase().as_str() {
"system" => Role::System,
"user" => Role::User,
"assistant" => Role::Assistant,
"tool" => Role::Tool,
_ => Role::User, // Default fallback
}
}
}
/// A message in the conversation
///
/// Optional fields are skipped during serialization when unset so the wire
/// format stays minimal.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Who produced this message.
    pub role: Role,
    /// Text content; `None` for assistant messages that carry only tool calls.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,
    /// Tool calls made by the assistant
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,
    /// For tool role messages: the ID of the tool call this responds to
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// For tool role messages: the name of the tool
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
impl ChatMessage {
/// Create a system message
pub fn system(content: impl Into<String>) -> Self {
Self {
role: Role::System,
content: Some(content.into()),
tool_calls: None,
tool_call_id: None,
name: None,
}
}
/// Create a user message
pub fn user(content: impl Into<String>) -> Self {
Self {
role: Role::User,
content: Some(content.into()),
tool_calls: None,
tool_call_id: None,
name: None,
}
}
/// Create an assistant message
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: Role::Assistant,
content: Some(content.into()),
tool_calls: None,
tool_call_id: None,
name: None,
}
}
/// Create an assistant message with tool calls (no text content)
pub fn assistant_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
Self {
role: Role::Assistant,
content: None,
tool_calls: Some(tool_calls),
tool_call_id: None,
name: None,
}
}
/// Create a tool result message
pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
Self {
role: Role::Tool,
content: Some(content.into()),
tool_calls: None,
tool_call_id: Some(tool_call_id.into()),
name: None,
}
}
}
// ============================================================================
// Tool Types
// ============================================================================
/// A tool call requested by the LLM
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
    /// Unique identifier for this tool call
    pub id: String,
    /// The type of tool call (always "function" for now)
    ///
    /// Defaults to "function" when the field is absent in a deserialized
    /// payload (see `default_function_type`).
    #[serde(rename = "type", default = "default_function_type")]
    pub call_type: String,
    /// The function being called
    pub function: FunctionCall,
}

/// Serde default for [`ToolCall::call_type`].
fn default_function_type() -> String {
    "function".to_string()
}
/// Details of a function call
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCall {
    /// Name of the function to call
    pub name: String,
    /// Arguments as a JSON object (already parsed, not a raw string)
    pub arguments: Value,
}
/// Definition of a tool available to the LLM
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    /// Tool kind; `"function"` is the only kind constructed here.
    #[serde(rename = "type")]
    pub tool_type: String,
    /// The callable function this tool exposes.
    pub function: ToolFunction,
}
impl Tool {
    /// Build a `"function"`-type tool from its name, description, and
    /// parameter schema.
    pub fn function(
        name: impl Into<String>,
        description: impl Into<String>,
        parameters: ToolParameters,
    ) -> Self {
        let function = ToolFunction {
            name: name.into(),
            description: description.into(),
            parameters,
        };
        Self {
            tool_type: String::from("function"),
            function,
        }
    }
}
/// Function definition within a tool
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunction {
    /// Name the model uses to invoke the function.
    pub name: String,
    /// Human-readable description presented to the model.
    pub description: String,
    /// Schema of the accepted arguments.
    pub parameters: ToolParameters,
}
/// Parameters schema for a function
///
/// Mirrors the JSON-Schema object layout used by chat-completion tool APIs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameters {
    /// Schema type; `"object"` for the common case (see [`ToolParameters::object`]).
    #[serde(rename = "type")]
    pub param_type: String,
    /// JSON Schema properties object
    pub properties: Value,
    /// Required parameter names
    pub required: Vec<String>,
}
impl ToolParameters {
/// Create an object parameter schema
pub fn object(properties: Value, required: Vec<String>) -> Self {
Self {
param_type: "object".to_string(),
properties,
required,
}
}
}
// ============================================================================
// Streaming Response Types
// ============================================================================
/// A chunk of a streaming response
///
/// Consumers accumulate `content` and `tool_calls` across chunks until a
/// chunk with `done == true` arrives.
#[derive(Debug, Clone)]
pub struct StreamChunk {
    /// Incremental text content
    pub content: Option<String>,
    /// Tool calls (may be partial/streaming)
    pub tool_calls: Option<Vec<ToolCallDelta>>,
    /// Whether this is the final chunk
    pub done: bool,
    /// Usage statistics (typically only in final chunk)
    pub usage: Option<Usage>,
}
/// Partial tool call for streaming
///
/// Deltas with the same `index` belong to the same tool call and are merged
/// by the default `LlmProvider::chat` implementation.
#[derive(Debug, Clone)]
pub struct ToolCallDelta {
    /// Index of this tool call in the array
    pub index: usize,
    /// Tool call ID (may only be present in first delta)
    pub id: Option<String>,
    /// Function name (may only be present in first delta)
    pub function_name: Option<String>,
    /// Incremental arguments string (JSON text fragment)
    pub arguments_delta: Option<String>,
}
/// Token usage statistics
#[derive(Debug, Clone, Default)]
pub struct Usage {
    /// Tokens consumed by the prompt/input.
    pub prompt_tokens: u32,
    /// Tokens generated in the completion.
    pub completion_tokens: u32,
    /// Total tokens (typically prompt + completion, as reported by the provider).
    pub total_tokens: u32,
}
// ============================================================================
// Provider Configuration
// ============================================================================
/// Options for a chat request
#[derive(Debug, Clone, Default)]
pub struct ChatOptions {
    /// Model to use
    pub model: String,
    /// Temperature (0.0 - 2.0)
    pub temperature: Option<f32>,
    /// Maximum tokens to generate
    pub max_tokens: Option<u32>,
    /// Top-p sampling
    pub top_p: Option<f32>,
    /// Stop sequences
    pub stop: Option<Vec<String>>,
}

impl ChatOptions {
    /// Create options for the given model with all sampling fields unset.
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            ..Default::default()
        }
    }

    /// Set the sampling temperature (builder style).
    pub fn with_temperature(mut self, temp: f32) -> Self {
        self.temperature = Some(temp);
        self
    }

    /// Set the maximum number of tokens to generate (builder style).
    pub fn with_max_tokens(mut self, max: u32) -> Self {
        self.max_tokens = Some(max);
        self
    }

    /// Set the top-p (nucleus sampling) cutoff (builder style).
    ///
    /// Added for parity with the other builders so every optional field can
    /// be set fluently.
    pub fn with_top_p(mut self, top_p: f32) -> Self {
        self.top_p = Some(top_p);
        self
    }

    /// Set the stop sequences (builder style).
    pub fn with_stop(mut self, stop: Vec<String>) -> Self {
        self.stop = Some(stop);
        self
    }
}
// ============================================================================
// Provider Trait
// ============================================================================
/// A boxed stream of chunks
///
/// Each item is either the next [`StreamChunk`] or a stream-level error.
pub type ChunkStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, LlmError>> + Send>>;

/// The main trait that all LLM providers must implement
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Get the provider name (e.g., "ollama", "anthropic", "openai")
    fn name(&self) -> &str;

    /// Get the current model name
    fn model(&self) -> &str;

    /// Send a chat request and receive a streaming response
    ///
    /// # Arguments
    /// * `messages` - The conversation history
    /// * `options` - Request options (model, temperature, etc.)
    /// * `tools` - Optional list of tools the model can use
    ///
    /// # Returns
    /// A stream of response chunks
    async fn chat_stream(
        &self,
        messages: &[ChatMessage],
        options: &ChatOptions,
        tools: Option<&[Tool]>,
    ) -> Result<ChunkStream, LlmError>;

    /// Send a chat request and receive a complete response (non-streaming)
    ///
    /// Default implementation collects the stream, but providers may override
    /// for efficiency.
    ///
    /// # Errors
    /// Propagates the first error yielded by `chat_stream` or by the stream
    /// itself; accumulation stops at that point.
    async fn chat(
        &self,
        messages: &[ChatMessage],
        options: &ChatOptions,
        tools: Option<&[Tool]>,
    ) -> Result<ChatResponse, LlmError> {
        use futures::StreamExt;
        let mut stream = self.chat_stream(messages, options, tools).await?;
        let mut content = String::new();
        // Tool calls arrive as indexed deltas; accumulate per index.
        let mut tool_calls: Vec<PartialToolCall> = Vec::new();
        let mut usage = None;
        while let Some(chunk) = stream.next().await {
            let chunk = chunk?;
            if let Some(text) = chunk.content {
                content.push_str(&text);
            }
            if let Some(deltas) = chunk.tool_calls {
                for delta in deltas {
                    // Grow the tool_calls vec if needed
                    while tool_calls.len() <= delta.index {
                        tool_calls.push(PartialToolCall::default());
                    }
                    let partial = &mut tool_calls[delta.index];
                    if let Some(id) = delta.id {
                        partial.id = Some(id);
                    }
                    if let Some(name) = delta.function_name {
                        partial.function_name = Some(name);
                    }
                    if let Some(args) = delta.arguments_delta {
                        // Arguments stream as JSON text fragments; concatenate
                        // now, parse once the stream is finished.
                        partial.arguments.push_str(&args);
                    }
                }
            }
            if chunk.usage.is_some() {
                // Usage typically arrives only on the final chunk; keep the
                // latest value seen.
                usage = chunk.usage;
            }
        }
        // Convert partial tool calls to complete tool calls
        // (partials that never received an id/name, or whose arguments fail
        // to parse, are dropped by `try_into_tool_call`).
        let final_tool_calls: Vec<ToolCall> = tool_calls
            .into_iter()
            .filter_map(|p| p.try_into_tool_call())
            .collect();
        Ok(ChatResponse {
            content: if content.is_empty() {
                None
            } else {
                Some(content)
            },
            tool_calls: if final_tool_calls.is_empty() {
                None
            } else {
                Some(final_tool_calls)
            },
            usage,
        })
    }
}
/// A complete chat response (non-streaming)
///
/// Produced by [`LlmProvider::chat`]; `content` and `tool_calls` are `None`
/// when empty rather than holding empty containers.
#[derive(Debug, Clone)]
pub struct ChatResponse {
    /// Full assistant text, if any was produced.
    pub content: Option<String>,
    /// Completed tool calls requested by the model, if any.
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Usage statistics, when the provider reported them.
    pub usage: Option<Usage>,
}
/// Helper for accumulating streaming tool calls
#[derive(Default)]
struct PartialToolCall {
id: Option<String>,
function_name: Option<String>,
arguments: String,
}
impl PartialToolCall {
fn try_into_tool_call(self) -> Option<ToolCall> {
let id = self.id?;
let name = self.function_name?;
let arguments: Value = serde_json::from_str(&self.arguments).ok()?;
Some(ToolCall {
id,
call_type: "function".to_string(),
function: FunctionCall { name, arguments },
})
}
}
// ============================================================================
// Authentication
// ============================================================================
/// Authentication method for LLM providers
///
/// Covers unauthenticated local providers, static API keys, and OAuth
/// tokens obtained through a login flow.
#[derive(Debug, Clone)]
pub enum AuthMethod {
    /// No authentication (for local providers like Ollama)
    None,
    /// API key authentication
    ApiKey(String),
    /// OAuth access token (from login flow)
    OAuth {
        access_token: String,
        refresh_token: Option<String>,
        expires_at: Option<u64>,
    },
}

impl AuthMethod {
    /// Build an API-key credential.
    pub fn api_key(key: impl Into<String>) -> Self {
        AuthMethod::ApiKey(key.into())
    }

    /// Build an OAuth credential holding only an access token.
    pub fn oauth(access_token: impl Into<String>) -> Self {
        AuthMethod::OAuth {
            access_token: access_token.into(),
            refresh_token: None,
            expires_at: None,
        }
    }

    /// Build an OAuth credential that can be refreshed later.
    pub fn oauth_with_refresh(
        access_token: impl Into<String>,
        refresh_token: impl Into<String>,
        expires_at: Option<u64>,
    ) -> Self {
        AuthMethod::OAuth {
            access_token: access_token.into(),
            refresh_token: Some(refresh_token.into()),
            expires_at,
        }
    }

    /// Token to place in an `Authorization: Bearer` header, if any.
    pub fn bearer_token(&self) -> Option<&str> {
        match self {
            AuthMethod::None => None,
            AuthMethod::ApiKey(key) => Some(key.as_str()),
            AuthMethod::OAuth { access_token, .. } => Some(access_token.as_str()),
        }
    }

    /// Whether the token should be refreshed now.
    ///
    /// Returns `true` only for OAuth credentials that carry a refresh token
    /// and whose expiry (Unix seconds) falls within the next five minutes.
    pub fn needs_refresh(&self) -> bool {
        let AuthMethod::OAuth {
            refresh_token: Some(_),
            expires_at: Some(exp),
            ..
        } = self
        else {
            return false;
        };
        let now_secs = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_or(0, |d| d.as_secs());
        // Treat anything expiring within the next 5 minutes as stale.
        *exp < now_secs + 300
    }
}
/// Device code response for OAuth device flow
///
/// Returned by [`OAuthProvider::start_device_auth`]; the client shows
/// `user_code`/`verification_uri` to the user and polls with `device_code`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceCodeResponse {
    /// Code the user enters on the verification page
    pub user_code: String,
    /// URL the user visits to authorize
    pub verification_uri: String,
    /// Full URL with code pre-filled (if supported)
    pub verification_uri_complete: Option<String>,
    /// Device code for polling (internal use)
    pub device_code: String,
    /// How often to poll (in seconds)
    pub interval: u64,
    /// When the codes expire (in seconds)
    pub expires_in: u64,
}
/// Result of polling for device authorization
///
/// `Pending` is the only non-terminal state; any other variant ends the
/// polling loop.
#[derive(Debug, Clone)]
pub enum DeviceAuthResult {
    /// Still waiting for user to authorize
    Pending,
    /// User authorized, here are the tokens
    Success {
        access_token: String,
        refresh_token: Option<String>,
        expires_in: Option<u64>,
    },
    /// User denied authorization
    Denied,
    /// Code expired
    Expired,
}
/// Trait for providers that support OAuth device flow
///
/// Implemented by providers that allow login without pasting an API key:
/// the user visits a URL, enters a short code, and the client polls until
/// authorization completes.
#[async_trait]
pub trait OAuthProvider {
    /// Start the device authorization flow
    async fn start_device_auth(&self) -> Result<DeviceCodeResponse, LlmError>;

    /// Poll for the authorization result
    ///
    /// Call at the cadence suggested by [`DeviceCodeResponse::interval`]
    /// until a terminal [`DeviceAuthResult`] is returned.
    async fn poll_device_auth(&self, device_code: &str) -> Result<DeviceAuthResult, LlmError>;

    /// Refresh an access token using a refresh token
    async fn refresh_token(&self, refresh_token: &str) -> Result<AuthMethod, LlmError>;
}
/// Stored credentials for a provider
///
/// Serializable form of an OAuth login, mirroring the fields of
/// [`AuthMethod::OAuth`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredCredentials {
    /// Provider these credentials belong to (e.g. "anthropic").
    pub provider: String,
    /// Current access token.
    pub access_token: String,
    /// Token used to obtain a new access token, if the flow issued one.
    pub refresh_token: Option<String>,
    /// Expiry timestamp, if reported (Unix seconds — compare
    /// `AuthMethod::needs_refresh`).
    pub expires_at: Option<u64>,
}
// ============================================================================
// Provider Status & Info
// ============================================================================
/// Status information for a provider connection
///
/// Snapshot returned by [`ProviderInfo::status`].
#[derive(Debug, Clone)]
pub struct ProviderStatus {
    /// Provider name
    pub provider: String,
    /// Whether the connection is authenticated
    pub authenticated: bool,
    /// Current user/account info if authenticated
    pub account: Option<AccountInfo>,
    /// Current model being used
    pub model: String,
    /// API endpoint URL
    pub endpoint: String,
    /// Whether the provider is reachable
    pub reachable: bool,
    /// Any status message or error
    pub message: Option<String>,
}
/// Account/user information from the provider
///
/// All fields are optional because providers differ in what they expose.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccountInfo {
    /// Account/user ID
    pub id: Option<String>,
    /// Display name or email
    pub name: Option<String>,
    /// Account email
    pub email: Option<String>,
    /// Account type (free, pro, team, enterprise)
    pub account_type: Option<String>,
    /// Organization name if applicable
    pub organization: Option<String>,
}
/// Usage statistics from the provider
///
/// Aggregate figures for the current billing/usage period; every field is
/// optional since providers report different subsets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageStats {
    /// Total tokens used in current period
    pub tokens_used: Option<u64>,
    /// Token limit for current period (if applicable)
    pub token_limit: Option<u64>,
    /// Number of requests made
    pub requests_made: Option<u64>,
    /// Request limit (if applicable)
    pub request_limit: Option<u64>,
    /// Cost incurred (if available)
    pub cost_usd: Option<f64>,
    /// Period start timestamp
    pub period_start: Option<u64>,
    /// Period end timestamp
    pub period_end: Option<u64>,
}
/// Available model information
///
/// Returned by [`ProviderInfo::list_models`] / [`ProviderInfo::model_info`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model ID/name
    pub id: String,
    /// Human-readable display name
    pub display_name: Option<String>,
    /// Model description
    pub description: Option<String>,
    /// Context window size (tokens)
    pub context_window: Option<u32>,
    /// Max output tokens
    pub max_output_tokens: Option<u32>,
    /// Whether the model supports tool use
    pub supports_tools: bool,
    /// Whether the model supports vision/images
    pub supports_vision: bool,
    /// Input token price per 1M tokens (USD)
    pub input_price_per_mtok: Option<f64>,
    /// Output token price per 1M tokens (USD)
    pub output_price_per_mtok: Option<f64>,
}
/// Trait for providers that support status/info queries
#[async_trait]
pub trait ProviderInfo {
    /// Get the current connection status
    async fn status(&self) -> Result<ProviderStatus, LlmError>;

    /// Get account information (if authenticated)
    async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError>;

    /// Get usage statistics (if available)
    async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError>;

    /// List available models
    async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError>;

    /// Look up a single model by ID.
    ///
    /// Default implementation scans [`list_models`](Self::list_models) and
    /// returns `Ok(None)` when no model matches.
    async fn model_info(&self, model_id: &str) -> Result<Option<ModelInfo>, LlmError> {
        for model in self.list_models().await? {
            if model.id == model_id {
                return Ok(Some(model));
            }
        }
        Ok(None)
    }
}
// ============================================================================
// Provider Factory
// ============================================================================
/// Supported LLM providers
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderType {
    Ollama,
    Anthropic,
    OpenAI,
}

impl ProviderType {
    /// Parse a provider name, case-insensitively.
    ///
    /// Accepts the aliases "claude" (Anthropic) and "gpt" (OpenAI); returns
    /// `None` for anything unrecognized.
    pub fn from_str(s: &str) -> Option<Self> {
        let normalized = s.to_lowercase();
        match normalized.as_str() {
            "ollama" => Some(ProviderType::Ollama),
            "anthropic" | "claude" => Some(ProviderType::Anthropic),
            "openai" | "gpt" => Some(ProviderType::OpenAI),
            _ => None,
        }
    }

    /// Canonical lowercase name of this provider.
    pub fn as_str(&self) -> &'static str {
        match self {
            ProviderType::Ollama => "ollama",
            ProviderType::Anthropic => "anthropic",
            ProviderType::OpenAI => "openai",
        }
    }

    /// Default model for this provider
    pub fn default_model(&self) -> &'static str {
        match self {
            ProviderType::Ollama => "qwen3:8b",
            ProviderType::Anthropic => "claude-sonnet-4-20250514",
            ProviderType::OpenAI => "gpt-4o",
        }
    }
}

impl std::fmt::Display for ProviderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.as_str())
    }
}