Multi-LLM Provider Support: - Add llm-core crate with LlmProvider trait abstraction - Implement Anthropic Claude API client with streaming - Implement OpenAI API client with streaming - Add token counting with SimpleTokenCounter and ClaudeTokenCounter - Add retry logic with exponential backoff and jitter Borderless TUI Redesign: - Rewrite theme system with terminal capability detection (Full/Unicode256/Basic) - Add provider tabs component with keybind switching [1]/[2]/[3] - Implement vim-modal input (Normal/Insert/Visual/Command modes) - Redesign chat panel with timestamps and streaming indicators - Add multi-provider status bar with cost tracking - Add Nerd Font icons with graceful ASCII fallbacks - Add syntax highlighting (syntect) and markdown rendering (pulldown-cmark) Advanced Agent Features: - Add system prompt builder with configurable components - Enhance subagent orchestration with parallel execution - Add git integration module for safe command detection - Add streaming tool results via channels - Expand tool set: AskUserQuestion, TodoWrite, LS, MultiEdit, BashOutput, KillShell - Add WebSearch with provider abstraction Plugin System Enhancement: - Add full agent definition parsing from YAML frontmatter - Add skill system with progressive disclosure - Wire plugin hooks into HookManager 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
797 lines
22 KiB
Rust
797 lines
22 KiB
Rust
//! LLM Provider Abstraction Layer
//!
//! This crate defines the common types and traits for LLM provider integration.
//! Providers (Ollama, Anthropic Claude, OpenAI) implement the `LlmProvider` trait
//! to enable swapping providers at runtime.
|
use async_trait::async_trait;
|
|
use futures::Stream;
|
|
use serde::{Deserialize, Serialize};
|
|
use serde_json::Value;
|
|
use std::pin::Pin;
|
|
use thiserror::Error;
|
|
|
|
// ============================================================================
|
|
// Public Modules
|
|
// ============================================================================
|
|
|
|
pub mod retry;
|
|
pub mod tokens;
|
|
|
|
// Re-export token counting types for convenience
|
|
pub use tokens::{ClaudeTokenCounter, ContextWindow, SimpleTokenCounter, TokenCounter};
|
|
|
|
// Re-export retry types for convenience
|
|
pub use retry::{is_retryable_error, RetryConfig, RetryStrategy};
|
|
|
|
// ============================================================================
|
|
// Error Types
|
|
// ============================================================================
|
|
|
|
/// Errors that can occur when talking to an LLM provider.
///
/// Display strings come from the `#[error(...)]` attributes via `thiserror`.
#[derive(Error, Debug)]
pub enum LlmError {
    /// Transport-level HTTP failure (connection, TLS, unexpected status).
    #[error("HTTP error: {0}")]
    Http(String),

    /// Response body could not be parsed as the expected JSON.
    #[error("JSON parsing error: {0}")]
    Json(String),

    /// Missing or rejected credentials.
    #[error("Authentication error: {0}")]
    Auth(String),

    /// Provider throttled the request; `retry_after_secs` is the
    /// server-suggested wait, when the provider sent one.
    #[error("Rate limit exceeded: retry after {retry_after_secs:?} seconds")]
    RateLimit { retry_after_secs: Option<u64> },

    /// Structured error returned by the provider's API, with an optional
    /// provider-specific error code.
    #[error("API error: {message}")]
    Api { message: String, code: Option<String> },

    /// Provider-specific failure that fits no other category.
    #[error("Provider error: {0}")]
    Provider(String),

    /// Failure while consuming a streaming response.
    #[error("Stream error: {0}")]
    Stream(String),

    /// Request exceeded its deadline.
    #[error("Request timeout: {0}")]
    Timeout(String),
}
|
|
|
|
// ============================================================================
|
|
// Message Types
|
|
// ============================================================================
|
|
|
|
/// Role of a message in the conversation
///
/// Serialized in lowercase ("system", "user", "assistant", "tool") to match
/// the wire format of OpenAI-style chat APIs.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
}
|
|
|
|
impl Role {
|
|
pub fn as_str(&self) -> &'static str {
|
|
match self {
|
|
Role::System => "system",
|
|
Role::User => "user",
|
|
Role::Assistant => "assistant",
|
|
Role::Tool => "tool",
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<&str> for Role {
|
|
fn from(s: &str) -> Self {
|
|
match s.to_lowercase().as_str() {
|
|
"system" => Role::System,
|
|
"user" => Role::User,
|
|
"assistant" => Role::Assistant,
|
|
"tool" => Role::Tool,
|
|
_ => Role::User, // Default fallback
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A message in the conversation
///
/// Optional fields are omitted from serialized output entirely
/// (`skip_serializing_if`) rather than sent as JSON `null`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Who produced this message (system/user/assistant/tool).
    pub role: Role,

    /// Text content; `None` for assistant messages that only carry tool calls.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,

    /// Tool calls made by the assistant
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,

    /// For tool role messages: the ID of the tool call this responds to
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,

    /// For tool role messages: the name of the tool
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
|
|
|
|
impl ChatMessage {
|
|
/// Create a system message
|
|
pub fn system(content: impl Into<String>) -> Self {
|
|
Self {
|
|
role: Role::System,
|
|
content: Some(content.into()),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
name: None,
|
|
}
|
|
}
|
|
|
|
/// Create a user message
|
|
pub fn user(content: impl Into<String>) -> Self {
|
|
Self {
|
|
role: Role::User,
|
|
content: Some(content.into()),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
name: None,
|
|
}
|
|
}
|
|
|
|
/// Create an assistant message
|
|
pub fn assistant(content: impl Into<String>) -> Self {
|
|
Self {
|
|
role: Role::Assistant,
|
|
content: Some(content.into()),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
name: None,
|
|
}
|
|
}
|
|
|
|
/// Create an assistant message with tool calls (no text content)
|
|
pub fn assistant_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
|
|
Self {
|
|
role: Role::Assistant,
|
|
content: None,
|
|
tool_calls: Some(tool_calls),
|
|
tool_call_id: None,
|
|
name: None,
|
|
}
|
|
}
|
|
|
|
/// Create a tool result message
|
|
pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
|
|
Self {
|
|
role: Role::Tool,
|
|
content: Some(content.into()),
|
|
tool_calls: None,
|
|
tool_call_id: Some(tool_call_id.into()),
|
|
name: None,
|
|
}
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Tool Types
|
|
// ============================================================================
|
|
|
|
/// A tool call requested by the LLM
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
    /// Unique identifier for this tool call
    pub id: String,

    /// The type of tool call (always "function" for now)
    ///
    /// Defaults to "function" when the field is absent on deserialization.
    #[serde(rename = "type", default = "default_function_type")]
    pub call_type: String,

    /// The function being called
    pub function: FunctionCall,
}
|
|
|
|
/// Serde default for `ToolCall::call_type`: tool calls are always "function".
fn default_function_type() -> String {
    String::from("function")
}
|
|
|
|
/// Details of a function call
///
/// The payload of a `ToolCall` — which function to run and with what input.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCall {
    /// Name of the function to call
    pub name: String,

    /// Arguments as a JSON object
    pub arguments: Value,
}
|
|
|
|
/// Definition of a tool available to the LLM
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    /// Tool kind discriminator; "function" for every tool built today.
    #[serde(rename = "type")]
    pub tool_type: String,

    /// The callable function this tool exposes.
    pub function: ToolFunction,
}
|
|
|
|
impl Tool {
|
|
/// Create a new function tool
|
|
pub fn function(
|
|
name: impl Into<String>,
|
|
description: impl Into<String>,
|
|
parameters: ToolParameters,
|
|
) -> Self {
|
|
Self {
|
|
tool_type: "function".to_string(),
|
|
function: ToolFunction {
|
|
name: name.into(),
|
|
description: description.into(),
|
|
parameters,
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Function definition within a tool
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunction {
    /// Name the model uses to invoke the tool.
    pub name: String,
    /// Human-readable description presented to the model.
    pub description: String,
    /// JSON-Schema description of the accepted arguments.
    pub parameters: ToolParameters,
}
|
|
|
|
/// Parameters schema for a function
///
/// Mirrors a JSON Schema object:
/// `{"type": "object", "properties": {...}, "required": [...]}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameters {
    /// Schema type; "object" for function parameter schemas.
    #[serde(rename = "type")]
    pub param_type: String,

    /// JSON Schema properties object
    pub properties: Value,

    /// Required parameter names
    pub required: Vec<String>,
}
|
|
|
|
impl ToolParameters {
|
|
/// Create an object parameter schema
|
|
pub fn object(properties: Value, required: Vec<String>) -> Self {
|
|
Self {
|
|
param_type: "object".to_string(),
|
|
properties,
|
|
required,
|
|
}
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Streaming Response Types
|
|
// ============================================================================
|
|
|
|
/// A chunk of a streaming response
///
/// Every field is optional per chunk; consumers accumulate chunks until
/// `done` is true (see the default `LlmProvider::chat` implementation).
#[derive(Debug, Clone)]
pub struct StreamChunk {
    /// Incremental text content
    pub content: Option<String>,

    /// Tool calls (may be partial/streaming)
    pub tool_calls: Option<Vec<ToolCallDelta>>,

    /// Whether this is the final chunk
    pub done: bool,

    /// Usage statistics (typically only in final chunk)
    pub usage: Option<Usage>,
}
|
|
|
|
/// Partial tool call for streaming
///
/// Deltas sharing the same `index` belong to the same tool call and are
/// merged by the consumer; `arguments_delta` fragments are concatenated.
#[derive(Debug, Clone)]
pub struct ToolCallDelta {
    /// Index of this tool call in the array
    pub index: usize,

    /// Tool call ID (may only be present in first delta)
    pub id: Option<String>,

    /// Function name (may only be present in first delta)
    pub function_name: Option<String>,

    /// Incremental arguments string
    pub arguments_delta: Option<String>,
}
|
|
|
|
/// Token usage statistics
#[derive(Debug, Clone, Default)]
pub struct Usage {
    /// Tokens consumed by the prompt/input.
    pub prompt_tokens: u32,
    /// Tokens generated by the model.
    pub completion_tokens: u32,
    /// Total as reported by the provider (normally prompt + completion).
    pub total_tokens: u32,
}
|
|
|
|
// ============================================================================
|
|
// Provider Configuration
|
|
// ============================================================================
|
|
|
|
/// Options for a chat request
///
/// Only `model` is required; unset sampling fields are left to the
/// provider's defaults.
#[derive(Debug, Clone, Default)]
pub struct ChatOptions {
    /// Model to use
    pub model: String,

    /// Temperature (0.0 - 2.0)
    pub temperature: Option<f32>,

    /// Maximum tokens to generate
    pub max_tokens: Option<u32>,

    /// Top-p sampling
    pub top_p: Option<f32>,

    /// Stop sequences
    pub stop: Option<Vec<String>>,
}
|
|
|
|
impl ChatOptions {
|
|
pub fn new(model: impl Into<String>) -> Self {
|
|
Self {
|
|
model: model.into(),
|
|
..Default::default()
|
|
}
|
|
}
|
|
|
|
pub fn with_temperature(mut self, temp: f32) -> Self {
|
|
self.temperature = Some(temp);
|
|
self
|
|
}
|
|
|
|
pub fn with_max_tokens(mut self, max: u32) -> Self {
|
|
self.max_tokens = Some(max);
|
|
self
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Provider Trait
|
|
// ============================================================================
|
|
|
|
/// A boxed stream of chunks
///
/// Pinned and `Send` so it can be returned from async trait methods and
/// driven from any executor thread.
pub type ChunkStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, LlmError>> + Send>>;
|
|
|
|
/// The main trait that all LLM providers must implement
///
/// `Send + Sync` so a provider can be shared behind `Arc` across tasks.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Get the provider name (e.g., "ollama", "anthropic", "openai")
    fn name(&self) -> &str;

    /// Get the current model name
    fn model(&self) -> &str;

    /// Send a chat request and receive a streaming response
    ///
    /// # Arguments
    /// * `messages` - The conversation history
    /// * `options` - Request options (model, temperature, etc.)
    /// * `tools` - Optional list of tools the model can use
    ///
    /// # Returns
    /// A stream of response chunks
    async fn chat_stream(
        &self,
        messages: &[ChatMessage],
        options: &ChatOptions,
        tools: Option<&[Tool]>,
    ) -> Result<ChunkStream, LlmError>;

    /// Send a chat request and receive a complete response (non-streaming)
    ///
    /// Default implementation collects the stream, but providers may override
    /// for efficiency.
    ///
    /// # Errors
    /// Returns the first error produced by `chat_stream` or by any chunk.
    async fn chat(
        &self,
        messages: &[ChatMessage],
        options: &ChatOptions,
        tools: Option<&[Tool]>,
    ) -> Result<ChatResponse, LlmError> {
        use futures::StreamExt;

        let mut stream = self.chat_stream(messages, options, tools).await?;
        let mut content = String::new();
        // Partial tool calls accumulated by their `index` in the response
        // array; deltas for the same index are merged into one entry.
        let mut tool_calls: Vec<PartialToolCall> = Vec::new();
        let mut usage = None;

        while let Some(chunk) = stream.next().await {
            // A chunk-level error aborts the whole collection.
            let chunk = chunk?;

            if let Some(text) = chunk.content {
                content.push_str(&text);
            }

            if let Some(deltas) = chunk.tool_calls {
                for delta in deltas {
                    // Grow the tool_calls vec if needed
                    while tool_calls.len() <= delta.index {
                        tool_calls.push(PartialToolCall::default());
                    }

                    // id / function_name arrive once (usually the first
                    // delta); arguments arrive as string fragments that are
                    // concatenated in order.
                    let partial = &mut tool_calls[delta.index];
                    if let Some(id) = delta.id {
                        partial.id = Some(id);
                    }
                    if let Some(name) = delta.function_name {
                        partial.function_name = Some(name);
                    }
                    if let Some(args) = delta.arguments_delta {
                        partial.arguments.push_str(&args);
                    }
                }
            }

            // Last chunk carrying usage wins (providers typically send it
            // only on the final chunk).
            if chunk.usage.is_some() {
                usage = chunk.usage;
            }
        }

        // Convert partial tool calls to complete tool calls
        let final_tool_calls: Vec<ToolCall> = tool_calls
            .into_iter()
            .filter_map(|p| p.try_into_tool_call())
            .collect();

        // Normalize "nothing produced" to None rather than empty containers.
        Ok(ChatResponse {
            content: if content.is_empty() {
                None
            } else {
                Some(content)
            },
            tool_calls: if final_tool_calls.is_empty() {
                None
            } else {
                Some(final_tool_calls)
            },
            usage,
        })
    }
}
|
|
|
|
/// A complete chat response (non-streaming)
#[derive(Debug, Clone)]
pub struct ChatResponse {
    /// Final text content, if the model produced any.
    pub content: Option<String>,
    /// Completed tool calls, if the model requested any.
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Token usage for the request, when reported by the provider.
    pub usage: Option<Usage>,
}
|
|
|
|
/// Helper for accumulating streaming tool calls
///
/// Fields start unset/empty and are filled in as `ToolCallDelta`s arrive;
/// `arguments` is the in-order concatenation of `arguments_delta` fragments.
#[derive(Default)]
struct PartialToolCall {
    id: Option<String>,
    function_name: Option<String>,
    arguments: String,
}
|
|
|
|
impl PartialToolCall {
|
|
fn try_into_tool_call(self) -> Option<ToolCall> {
|
|
let id = self.id?;
|
|
let name = self.function_name?;
|
|
let arguments: Value = serde_json::from_str(&self.arguments).ok()?;
|
|
|
|
Some(ToolCall {
|
|
id,
|
|
call_type: "function".to_string(),
|
|
function: FunctionCall { name, arguments },
|
|
})
|
|
}
|
|
}
|
|
|
|
// ============================================================================
|
|
// Authentication
|
|
// ============================================================================
|
|
|
|
/// Authentication method for LLM providers
#[derive(Debug, Clone)]
pub enum AuthMethod {
    /// No authentication (for local providers like Ollama)
    None,

    /// API key authentication
    ApiKey(String),

    /// OAuth access token (from login flow)
    OAuth {
        /// Bearer token sent on requests.
        access_token: String,
        /// Token used to obtain a new access token; `None` when the flow
        /// did not issue one.
        refresh_token: Option<String>,
        /// Unix timestamp (seconds) when the access token expires, if known.
        expires_at: Option<u64>,
    },
}
|
|
|
|
impl AuthMethod {
|
|
/// Create API key auth
|
|
pub fn api_key(key: impl Into<String>) -> Self {
|
|
Self::ApiKey(key.into())
|
|
}
|
|
|
|
/// Create OAuth auth from tokens
|
|
pub fn oauth(access_token: impl Into<String>) -> Self {
|
|
Self::OAuth {
|
|
access_token: access_token.into(),
|
|
refresh_token: None,
|
|
expires_at: None,
|
|
}
|
|
}
|
|
|
|
/// Create OAuth auth with refresh token
|
|
pub fn oauth_with_refresh(
|
|
access_token: impl Into<String>,
|
|
refresh_token: impl Into<String>,
|
|
expires_at: Option<u64>,
|
|
) -> Self {
|
|
Self::OAuth {
|
|
access_token: access_token.into(),
|
|
refresh_token: Some(refresh_token.into()),
|
|
expires_at,
|
|
}
|
|
}
|
|
|
|
/// Get the bearer token for Authorization header
|
|
pub fn bearer_token(&self) -> Option<&str> {
|
|
match self {
|
|
Self::None => None,
|
|
Self::ApiKey(key) => Some(key),
|
|
Self::OAuth { access_token, .. } => Some(access_token),
|
|
}
|
|
}
|
|
|
|
/// Check if token might need refresh
|
|
pub fn needs_refresh(&self) -> bool {
|
|
match self {
|
|
Self::OAuth {
|
|
expires_at: Some(exp),
|
|
refresh_token: Some(_),
|
|
..
|
|
} => {
|
|
let now = std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.map(|d| d.as_secs())
|
|
.unwrap_or(0);
|
|
// Refresh if expiring within 5 minutes
|
|
*exp < now + 300
|
|
}
|
|
_ => false,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Device code response for OAuth device flow
///
/// First step of the device flow: the user enters `user_code` at
/// `verification_uri` while the client polls with `device_code`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceCodeResponse {
    /// Code the user enters on the verification page
    pub user_code: String,

    /// URL the user visits to authorize
    pub verification_uri: String,

    /// Full URL with code pre-filled (if supported)
    pub verification_uri_complete: Option<String>,

    /// Device code for polling (internal use)
    pub device_code: String,

    /// How often to poll (in seconds)
    pub interval: u64,

    /// When the codes expire (in seconds)
    pub expires_in: u64,
}
|
|
|
|
/// Result of polling for device authorization
///
/// Returned by `OAuthProvider::poll_device_auth`; callers keep polling
/// while `Pending` and stop on any other variant.
#[derive(Debug, Clone)]
pub enum DeviceAuthResult {
    /// Still waiting for user to authorize
    Pending,

    /// User authorized, here are the tokens
    Success {
        /// Bearer token for subsequent API requests.
        access_token: String,
        /// Refresh token, if the provider issued one.
        refresh_token: Option<String>,
        /// Access-token lifetime in seconds, if the provider reported it.
        expires_in: Option<u64>,
    },

    /// User denied authorization
    Denied,

    /// Code expired
    Expired,
}
|
|
|
|
/// Trait for providers that support OAuth device flow
#[async_trait]
pub trait OAuthProvider {
    /// Start the device authorization flow
    ///
    /// Returns the codes/URLs the user needs plus the polling parameters.
    async fn start_device_auth(&self) -> Result<DeviceCodeResponse, LlmError>;

    /// Poll for the authorization result
    ///
    /// `device_code` comes from `start_device_auth`; call at the interval
    /// that response specifies.
    async fn poll_device_auth(&self, device_code: &str) -> Result<DeviceAuthResult, LlmError>;

    /// Refresh an access token using a refresh token
    async fn refresh_token(&self, refresh_token: &str) -> Result<AuthMethod, LlmError>;
}
|
|
|
|
/// Stored credentials for a provider
///
/// Serializable snapshot mirroring the fields of `AuthMethod::OAuth`,
/// tagged with the provider it belongs to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredCredentials {
    /// Provider name these credentials belong to (e.g. "anthropic").
    pub provider: String,
    /// Current OAuth access token.
    pub access_token: String,
    /// Refresh token, if one was issued.
    pub refresh_token: Option<String>,
    /// Unix timestamp (seconds) when the access token expires, if known.
    pub expires_at: Option<u64>,
}
|
|
|
|
// ============================================================================
|
|
// Provider Status & Info
|
|
// ============================================================================
|
|
|
|
/// Status information for a provider connection
///
/// Snapshot returned by `ProviderInfo::status`, e.g. for a status bar.
#[derive(Debug, Clone)]
pub struct ProviderStatus {
    /// Provider name
    pub provider: String,

    /// Whether the connection is authenticated
    pub authenticated: bool,

    /// Current user/account info if authenticated
    pub account: Option<AccountInfo>,

    /// Current model being used
    pub model: String,

    /// API endpoint URL
    pub endpoint: String,

    /// Whether the provider is reachable
    pub reachable: bool,

    /// Any status message or error
    pub message: Option<String>,
}
|
|
|
|
/// Account/user information from the provider
///
/// All fields optional — providers expose different subsets of this data.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccountInfo {
    /// Account/user ID
    pub id: Option<String>,

    /// Display name or email
    pub name: Option<String>,

    /// Account email
    pub email: Option<String>,

    /// Account type (free, pro, team, enterprise)
    pub account_type: Option<String>,

    /// Organization name if applicable
    pub organization: Option<String>,
}
|
|
|
|
/// Usage statistics from the provider
///
/// Billing-period level usage, as opposed to per-request `Usage`. All
/// fields optional — providers report different subsets.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageStats {
    /// Total tokens used in current period
    pub tokens_used: Option<u64>,

    /// Token limit for current period (if applicable)
    pub token_limit: Option<u64>,

    /// Number of requests made
    pub requests_made: Option<u64>,

    /// Request limit (if applicable)
    pub request_limit: Option<u64>,

    /// Cost incurred (if available)
    pub cost_usd: Option<f64>,

    /// Period start timestamp
    pub period_start: Option<u64>,

    /// Period end timestamp
    pub period_end: Option<u64>,
}
|
|
|
|
/// Available model information
///
/// Returned by `ProviderInfo::list_models` / `model_info`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model ID/name
    pub id: String,

    /// Human-readable display name
    pub display_name: Option<String>,

    /// Model description
    pub description: Option<String>,

    /// Context window size (tokens)
    pub context_window: Option<u32>,

    /// Max output tokens
    pub max_output_tokens: Option<u32>,

    /// Whether the model supports tool use
    pub supports_tools: bool,

    /// Whether the model supports vision/images
    pub supports_vision: bool,

    /// Input token price per 1M tokens (USD)
    pub input_price_per_mtok: Option<f64>,

    /// Output token price per 1M tokens (USD)
    pub output_price_per_mtok: Option<f64>,
}
|
|
|
|
/// Trait for providers that support status/info queries
#[async_trait]
pub trait ProviderInfo {
    /// Get the current connection status
    async fn status(&self) -> Result<ProviderStatus, LlmError>;

    /// Get account information (if authenticated)
    async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError>;

    /// Get usage statistics (if available)
    async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError>;

    /// List available models
    async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError>;

    /// Check if a specific model is available
    ///
    /// Default implementation fetches the full model list and scans it;
    /// providers with a direct lookup endpoint should override.
    async fn model_info(&self, model_id: &str) -> Result<Option<ModelInfo>, LlmError> {
        let models = self.list_models().await?;
        Ok(models.into_iter().find(|m| m.id == model_id))
    }
}
|
|
|
|
// ============================================================================
|
|
// Provider Factory
|
|
// ============================================================================
|
|
|
|
/// Supported LLM providers
///
/// Serialized in lowercase ("ollama", "anthropic", "openai"), matching
/// `ProviderType::as_str`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderType {
    Ollama,
    Anthropic,
    OpenAI,
}
|
|
|
|
impl ProviderType {
|
|
pub fn from_str(s: &str) -> Option<Self> {
|
|
match s.to_lowercase().as_str() {
|
|
"ollama" => Some(Self::Ollama),
|
|
"anthropic" | "claude" => Some(Self::Anthropic),
|
|
"openai" | "gpt" => Some(Self::OpenAI),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
pub fn as_str(&self) -> &'static str {
|
|
match self {
|
|
Self::Ollama => "ollama",
|
|
Self::Anthropic => "anthropic",
|
|
Self::OpenAI => "openai",
|
|
}
|
|
}
|
|
|
|
/// Default model for this provider
|
|
pub fn default_model(&self) -> &'static str {
|
|
match self {
|
|
Self::Ollama => "qwen3:8b",
|
|
Self::Anthropic => "claude-sonnet-4-20250514",
|
|
Self::OpenAI => "gpt-4o",
|
|
}
|
|
}
|
|
}
|
|
|
|
impl std::fmt::Display for ProviderType {
|
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
|
write!(f, "{}", self.as_str())
|
|
}
|
|
}
|