//! LLM Provider Abstraction Layer //! //! This crate defines the common types and traits for LLM provider integration. //! Providers (Ollama, Anthropic Claude, OpenAI) implement the `LlmProvider` trait //! to enable swapping providers at runtime. use async_trait::async_trait; use futures::Stream; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::pin::Pin; use thiserror::Error; // ============================================================================ // Public Modules // ============================================================================ pub mod retry; pub mod tokens; // Re-export token counting types for convenience pub use tokens::{ClaudeTokenCounter, ContextWindow, SimpleTokenCounter, TokenCounter}; // Re-export retry types for convenience pub use retry::{is_retryable_error, RetryConfig, RetryStrategy}; // ============================================================================ // Error Types // ============================================================================ #[derive(Error, Debug)] pub enum LlmError { #[error("HTTP error: {0}")] Http(String), #[error("JSON parsing error: {0}")] Json(String), #[error("Authentication error: {0}")] Auth(String), #[error("Rate limit exceeded: retry after {retry_after_secs:?} seconds")] RateLimit { retry_after_secs: Option }, #[error("API error: {message}")] Api { message: String, code: Option }, #[error("Provider error: {0}")] Provider(String), #[error("Stream error: {0}")] Stream(String), #[error("Request timeout: {0}")] Timeout(String), } // ============================================================================ // Message Types // ============================================================================ /// Role of a message in the conversation #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum Role { System, User, Assistant, Tool, } impl Role { pub fn as_str(&self) -> &'static str { match self { Role::System => 
"system", Role::User => "user", Role::Assistant => "assistant", Role::Tool => "tool", } } } impl From<&str> for Role { fn from(s: &str) -> Self { match s.to_lowercase().as_str() { "system" => Role::System, "user" => Role::User, "assistant" => Role::Assistant, "tool" => Role::Tool, _ => Role::User, // Default fallback } } } /// A message in the conversation #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ChatMessage { pub role: Role, #[serde(skip_serializing_if = "Option::is_none")] pub content: Option, /// Tool calls made by the assistant #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, /// For tool role messages: the ID of the tool call this responds to #[serde(skip_serializing_if = "Option::is_none")] pub tool_call_id: Option, /// For tool role messages: the name of the tool #[serde(skip_serializing_if = "Option::is_none")] pub name: Option, } impl ChatMessage { /// Create a system message pub fn system(content: impl Into) -> Self { Self { role: Role::System, content: Some(content.into()), tool_calls: None, tool_call_id: None, name: None, } } /// Create a user message pub fn user(content: impl Into) -> Self { Self { role: Role::User, content: Some(content.into()), tool_calls: None, tool_call_id: None, name: None, } } /// Create an assistant message pub fn assistant(content: impl Into) -> Self { Self { role: Role::Assistant, content: Some(content.into()), tool_calls: None, tool_call_id: None, name: None, } } /// Create an assistant message with tool calls (no text content) pub fn assistant_tool_calls(tool_calls: Vec) -> Self { Self { role: Role::Assistant, content: None, tool_calls: Some(tool_calls), tool_call_id: None, name: None, } } /// Create a tool result message pub fn tool_result(tool_call_id: impl Into, content: impl Into) -> Self { Self { role: Role::Tool, content: Some(content.into()), tool_calls: None, tool_call_id: Some(tool_call_id.into()), name: None, } } } // 
============================================================================ // Tool Types // ============================================================================ /// A tool call requested by the LLM #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct ToolCall { /// Unique identifier for this tool call pub id: String, /// The type of tool call (always "function" for now) #[serde(rename = "type", default = "default_function_type")] pub call_type: String, /// The function being called pub function: FunctionCall, } fn default_function_type() -> String { "function".to_string() } /// Details of a function call #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] pub struct FunctionCall { /// Name of the function to call pub name: String, /// Arguments as a JSON object pub arguments: Value, } /// Definition of a tool available to the LLM #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Tool { #[serde(rename = "type")] pub tool_type: String, pub function: ToolFunction, } impl Tool { /// Create a new function tool pub fn function( name: impl Into, description: impl Into, parameters: ToolParameters, ) -> Self { Self { tool_type: "function".to_string(), function: ToolFunction { name: name.into(), description: description.into(), parameters, }, } } } /// Function definition within a tool #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolFunction { pub name: String, pub description: String, pub parameters: ToolParameters, } /// Parameters schema for a function #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolParameters { #[serde(rename = "type")] pub param_type: String, /// JSON Schema properties object pub properties: Value, /// Required parameter names pub required: Vec, } impl ToolParameters { /// Create an object parameter schema pub fn object(properties: Value, required: Vec) -> Self { Self { param_type: "object".to_string(), properties, required, } } } // 
// ============================================================================
// Streaming Response Types
// ============================================================================

/// A chunk of a streaming response
#[derive(Debug, Clone)]
pub struct StreamChunk {
    /// Incremental text content
    pub content: Option<String>,
    /// Tool calls (may be partial/streaming)
    pub tool_calls: Option<Vec<ToolCallDelta>>,
    /// Whether this is the final chunk
    pub done: bool,
    /// Usage statistics (typically only in final chunk)
    pub usage: Option<Usage>,
}

/// Partial tool call for streaming
#[derive(Debug, Clone)]
pub struct ToolCallDelta {
    /// Index of this tool call in the array
    pub index: usize,
    /// Tool call ID (may only be present in first delta)
    pub id: Option<String>,
    /// Function name (may only be present in first delta)
    pub function_name: Option<String>,
    /// Incremental arguments string
    pub arguments_delta: Option<String>,
}

/// Token usage statistics
#[derive(Debug, Clone, Default)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

// ============================================================================
// Provider Configuration
// ============================================================================

/// Options for a chat request
#[derive(Debug, Clone, Default)]
pub struct ChatOptions {
    /// Model to use
    pub model: String,
    /// Temperature (0.0 - 2.0)
    pub temperature: Option<f32>,
    /// Maximum tokens to generate
    pub max_tokens: Option<u32>,
    /// Top-p sampling
    pub top_p: Option<f32>,
    /// Stop sequences
    pub stop: Option<Vec<String>>,
}

impl ChatOptions {
    /// Options for `model` with all sampling knobs left at provider defaults.
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            ..Default::default()
        }
    }

    /// Builder-style setter for the sampling temperature.
    pub fn with_temperature(mut self, temp: f32) -> Self {
        self.temperature = Some(temp);
        self
    }

    /// Builder-style setter for the generation token cap.
    pub fn with_max_tokens(mut self, max: u32) -> Self {
        self.max_tokens = Some(max);
        self
    }
}
============================================================================ /// A boxed stream of chunks pub type ChunkStream = Pin> + Send>>; /// The main trait that all LLM providers must implement #[async_trait] pub trait LlmProvider: Send + Sync { /// Get the provider name (e.g., "ollama", "anthropic", "openai") fn name(&self) -> &str; /// Get the current model name fn model(&self) -> &str; /// Send a chat request and receive a streaming response /// /// # Arguments /// * `messages` - The conversation history /// * `options` - Request options (model, temperature, etc.) /// * `tools` - Optional list of tools the model can use /// /// # Returns /// A stream of response chunks async fn chat_stream( &self, messages: &[ChatMessage], options: &ChatOptions, tools: Option<&[Tool]>, ) -> Result; /// Send a chat request and receive a complete response (non-streaming) /// /// Default implementation collects the stream, but providers may override /// for efficiency. async fn chat( &self, messages: &[ChatMessage], options: &ChatOptions, tools: Option<&[Tool]>, ) -> Result { use futures::StreamExt; let mut stream = self.chat_stream(messages, options, tools).await?; let mut content = String::new(); let mut tool_calls: Vec = Vec::new(); let mut usage = None; while let Some(chunk) = stream.next().await { let chunk = chunk?; if let Some(text) = chunk.content { content.push_str(&text); } if let Some(deltas) = chunk.tool_calls { for delta in deltas { // Grow the tool_calls vec if needed while tool_calls.len() <= delta.index { tool_calls.push(PartialToolCall::default()); } let partial = &mut tool_calls[delta.index]; if let Some(id) = delta.id { partial.id = Some(id); } if let Some(name) = delta.function_name { partial.function_name = Some(name); } if let Some(args) = delta.arguments_delta { partial.arguments.push_str(&args); } } } if chunk.usage.is_some() { usage = chunk.usage; } } // Convert partial tool calls to complete tool calls let final_tool_calls: Vec = tool_calls 
.into_iter() .filter_map(|p| p.try_into_tool_call()) .collect(); Ok(ChatResponse { content: if content.is_empty() { None } else { Some(content) }, tool_calls: if final_tool_calls.is_empty() { None } else { Some(final_tool_calls) }, usage, }) } } /// A complete chat response (non-streaming) #[derive(Debug, Clone)] pub struct ChatResponse { pub content: Option, pub tool_calls: Option>, pub usage: Option, } /// Helper for accumulating streaming tool calls #[derive(Default)] struct PartialToolCall { id: Option, function_name: Option, arguments: String, } impl PartialToolCall { fn try_into_tool_call(self) -> Option { let id = self.id?; let name = self.function_name?; let arguments: Value = serde_json::from_str(&self.arguments).ok()?; Some(ToolCall { id, call_type: "function".to_string(), function: FunctionCall { name, arguments }, }) } } // ============================================================================ // Authentication // ============================================================================ /// Authentication method for LLM providers #[derive(Debug, Clone)] pub enum AuthMethod { /// No authentication (for local providers like Ollama) None, /// API key authentication ApiKey(String), /// OAuth access token (from login flow) OAuth { access_token: String, refresh_token: Option, expires_at: Option, }, } impl AuthMethod { /// Create API key auth pub fn api_key(key: impl Into) -> Self { Self::ApiKey(key.into()) } /// Create OAuth auth from tokens pub fn oauth(access_token: impl Into) -> Self { Self::OAuth { access_token: access_token.into(), refresh_token: None, expires_at: None, } } /// Create OAuth auth with refresh token pub fn oauth_with_refresh( access_token: impl Into, refresh_token: impl Into, expires_at: Option, ) -> Self { Self::OAuth { access_token: access_token.into(), refresh_token: Some(refresh_token.into()), expires_at, } } /// Get the bearer token for Authorization header pub fn bearer_token(&self) -> Option<&str> { match self { Self::None 
=> None, Self::ApiKey(key) => Some(key), Self::OAuth { access_token, .. } => Some(access_token), } } /// Check if token might need refresh pub fn needs_refresh(&self) -> bool { match self { Self::OAuth { expires_at: Some(exp), refresh_token: Some(_), .. } => { let now = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .map(|d| d.as_secs()) .unwrap_or(0); // Refresh if expiring within 5 minutes *exp < now + 300 } _ => false, } } } /// Device code response for OAuth device flow #[derive(Debug, Clone, Serialize, Deserialize)] pub struct DeviceCodeResponse { /// Code the user enters on the verification page pub user_code: String, /// URL the user visits to authorize pub verification_uri: String, /// Full URL with code pre-filled (if supported) pub verification_uri_complete: Option, /// Device code for polling (internal use) pub device_code: String, /// How often to poll (in seconds) pub interval: u64, /// When the codes expire (in seconds) pub expires_in: u64, } /// Result of polling for device authorization #[derive(Debug, Clone)] pub enum DeviceAuthResult { /// Still waiting for user to authorize Pending, /// User authorized, here are the tokens Success { access_token: String, refresh_token: Option, expires_in: Option, }, /// User denied authorization Denied, /// Code expired Expired, } /// Trait for providers that support OAuth device flow #[async_trait] pub trait OAuthProvider { /// Start the device authorization flow async fn start_device_auth(&self) -> Result; /// Poll for the authorization result async fn poll_device_auth(&self, device_code: &str) -> Result; /// Refresh an access token using a refresh token async fn refresh_token(&self, refresh_token: &str) -> Result; } /// Stored credentials for a provider #[derive(Debug, Clone, Serialize, Deserialize)] pub struct StoredCredentials { pub provider: String, pub access_token: String, pub refresh_token: Option, pub expires_at: Option, } // 
============================================================================ // Provider Status & Info // ============================================================================ /// Status information for a provider connection #[derive(Debug, Clone)] pub struct ProviderStatus { /// Provider name pub provider: String, /// Whether the connection is authenticated pub authenticated: bool, /// Current user/account info if authenticated pub account: Option, /// Current model being used pub model: String, /// API endpoint URL pub endpoint: String, /// Whether the provider is reachable pub reachable: bool, /// Any status message or error pub message: Option, } /// Account/user information from the provider #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AccountInfo { /// Account/user ID pub id: Option, /// Display name or email pub name: Option, /// Account email pub email: Option, /// Account type (free, pro, team, enterprise) pub account_type: Option, /// Organization name if applicable pub organization: Option, } /// Usage statistics from the provider #[derive(Debug, Clone, Serialize, Deserialize)] pub struct UsageStats { /// Total tokens used in current period pub tokens_used: Option, /// Token limit for current period (if applicable) pub token_limit: Option, /// Number of requests made pub requests_made: Option, /// Request limit (if applicable) pub request_limit: Option, /// Cost incurred (if available) pub cost_usd: Option, /// Period start timestamp pub period_start: Option, /// Period end timestamp pub period_end: Option, } /// Available model information #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ModelInfo { /// Model ID/name pub id: String, /// Human-readable display name pub display_name: Option, /// Model description pub description: Option, /// Context window size (tokens) pub context_window: Option, /// Max output tokens pub max_output_tokens: Option, /// Whether the model supports tool use pub supports_tools: bool, /// 
Whether the model supports vision/images pub supports_vision: bool, /// Input token price per 1M tokens (USD) pub input_price_per_mtok: Option, /// Output token price per 1M tokens (USD) pub output_price_per_mtok: Option, } /// Trait for providers that support status/info queries #[async_trait] pub trait ProviderInfo { /// Get the current connection status async fn status(&self) -> Result; /// Get account information (if authenticated) async fn account_info(&self) -> Result, LlmError>; /// Get usage statistics (if available) async fn usage_stats(&self) -> Result, LlmError>; /// List available models async fn list_models(&self) -> Result, LlmError>; /// Check if a specific model is available async fn model_info(&self, model_id: &str) -> Result, LlmError> { let models = self.list_models().await?; Ok(models.into_iter().find(|m| m.id == model_id)) } } // ============================================================================ // Provider Factory // ============================================================================ /// Supported LLM providers #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "lowercase")] pub enum ProviderType { Ollama, Anthropic, OpenAI, } impl ProviderType { pub fn from_str(s: &str) -> Option { match s.to_lowercase().as_str() { "ollama" => Some(Self::Ollama), "anthropic" | "claude" => Some(Self::Anthropic), "openai" | "gpt" => Some(Self::OpenAI), _ => None, } } pub fn as_str(&self) -> &'static str { match self { Self::Ollama => "ollama", Self::Anthropic => "anthropic", Self::OpenAI => "openai", } } /// Default model for this provider pub fn default_model(&self) -> &'static str { match self { Self::Ollama => "qwen3:8b", Self::Anthropic => "claude-sonnet-4-20250514", Self::OpenAI => "gpt-4o", } } } impl std::fmt::Display for ProviderType { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.as_str()) } }