feat(v2): complete multi-LLM providers, TUI redesign, and advanced agent features

Multi-LLM Provider Support:
- Add llm-core crate with LlmProvider trait abstraction
- Implement Anthropic Claude API client with streaming
- Implement OpenAI API client with streaming
- Add token counting with SimpleTokenCounter and ClaudeTokenCounter
- Add retry logic with exponential backoff and jitter

Borderless TUI Redesign:
- Rewrite theme system with terminal capability detection (Full/Unicode256/Basic)
- Add provider tabs component with keybind switching [1]/[2]/[3]
- Implement vim-modal input (Normal/Insert/Visual/Command modes)
- Redesign chat panel with timestamps and streaming indicators
- Add multi-provider status bar with cost tracking
- Add Nerd Font icons with graceful ASCII fallbacks
- Add syntax highlighting (syntect) and markdown rendering (pulldown-cmark)

Advanced Agent Features:
- Add system prompt builder with configurable components
- Enhance subagent orchestration with parallel execution
- Add git integration module for safe command detection
- Add streaming tool results via channels
- Expand tool set: AskUserQuestion, TodoWrite, LS, MultiEdit, BashOutput, KillShell
- Add WebSearch with provider abstraction

Plugin System Enhancement:
- Add full agent definition parsing from YAML frontmatter
- Add skill system with progressive disclosure
- Wire plugin hooks into HookManager

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-12-02 17:24:14 +01:00
parent 09c8c9d83e
commit 10c8e2baae
67 changed files with 11444 additions and 626 deletions

View File

@@ -0,0 +1,18 @@
[package]
name = "llm-anthropic"
version = "0.1.0"
edition.workspace = true
license.workspace = true
description = "Anthropic Claude API client for Owlen"
[dependencies]
llm-core = { path = "../core" }
async-trait = "0.1"
futures = "0.3"
reqwest = { version = "0.12", features = ["json", "stream"] }
reqwest-eventsource = "0.6"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1", features = ["sync", "time"] }
tracing = "0.1"
uuid = { version = "1.0", features = ["v4"] }

View File

@@ -0,0 +1,285 @@
//! Anthropic OAuth Authentication
//!
//! Implements device code flow for authenticating with Anthropic without API keys.
use llm_core::{AuthMethod, DeviceAuthResult, DeviceCodeResponse, LlmError, OAuthProvider};
use reqwest::Client;
use serde::{Deserialize, Serialize};
/// OAuth client for Anthropic device flow
pub struct AnthropicAuth {
http: Client,
client_id: String,
}
// Anthropic OAuth endpoints (these would be the real endpoints)
const AUTH_BASE_URL: &str = "https://console.anthropic.com";
const DEVICE_CODE_ENDPOINT: &str = "/oauth/device/code";
const TOKEN_ENDPOINT: &str = "/oauth/token";
// Default client ID for Owlen CLI
const DEFAULT_CLIENT_ID: &str = "owlen-cli";
impl AnthropicAuth {
/// Create a new OAuth client with the default CLI client ID
pub fn new() -> Self {
Self {
http: Client::new(),
client_id: DEFAULT_CLIENT_ID.to_string(),
}
}
/// Create with a custom client ID
pub fn with_client_id(client_id: impl Into<String>) -> Self {
Self {
http: Client::new(),
client_id: client_id.into(),
}
}
}
impl Default for AnthropicAuth {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Serialize)]
struct DeviceCodeRequest<'a> {
client_id: &'a str,
scope: &'a str,
}
#[derive(Debug, Deserialize)]
struct DeviceCodeApiResponse {
device_code: String,
user_code: String,
verification_uri: String,
verification_uri_complete: Option<String>,
expires_in: u64,
interval: u64,
}
#[derive(Debug, Serialize)]
struct TokenRequest<'a> {
client_id: &'a str,
device_code: &'a str,
grant_type: &'a str,
}
#[derive(Debug, Deserialize)]
struct TokenApiResponse {
access_token: String,
#[allow(dead_code)]
token_type: String,
expires_in: Option<u64>,
refresh_token: Option<String>,
}
#[derive(Debug, Deserialize)]
struct TokenErrorResponse {
error: String,
error_description: Option<String>,
}
#[async_trait::async_trait]
impl OAuthProvider for AnthropicAuth {
async fn start_device_auth(&self) -> Result<DeviceCodeResponse, LlmError> {
let url = format!("{}{}", AUTH_BASE_URL, DEVICE_CODE_ENDPOINT);
let request = DeviceCodeRequest {
client_id: &self.client_id,
scope: "api:read api:write", // Request API access
};
let response = self
.http
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| LlmError::Http(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(LlmError::Auth(format!(
"Device code request failed ({}): {}",
status, text
)));
}
let api_response: DeviceCodeApiResponse = response
.json()
.await
.map_err(|e| LlmError::Json(e.to_string()))?;
Ok(DeviceCodeResponse {
device_code: api_response.device_code,
user_code: api_response.user_code,
verification_uri: api_response.verification_uri,
verification_uri_complete: api_response.verification_uri_complete,
expires_in: api_response.expires_in,
interval: api_response.interval,
})
}
async fn poll_device_auth(&self, device_code: &str) -> Result<DeviceAuthResult, LlmError> {
let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);
let request = TokenRequest {
client_id: &self.client_id,
device_code,
grant_type: "urn:ietf:params:oauth:grant-type:device_code",
};
let response = self
.http
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| LlmError::Http(e.to_string()))?;
if response.status().is_success() {
let token_response: TokenApiResponse = response
.json()
.await
.map_err(|e| LlmError::Json(e.to_string()))?;
return Ok(DeviceAuthResult::Success {
access_token: token_response.access_token,
refresh_token: token_response.refresh_token,
expires_in: token_response.expires_in,
});
}
// Parse error response
let error_response: TokenErrorResponse = response
.json()
.await
.map_err(|e| LlmError::Json(e.to_string()))?;
match error_response.error.as_str() {
"authorization_pending" => Ok(DeviceAuthResult::Pending),
"slow_down" => Ok(DeviceAuthResult::Pending), // Treat as pending, caller should slow down
"access_denied" => Ok(DeviceAuthResult::Denied),
"expired_token" => Ok(DeviceAuthResult::Expired),
_ => Err(LlmError::Auth(format!(
"Token request failed: {} - {}",
error_response.error,
error_response.error_description.unwrap_or_default()
))),
}
}
async fn refresh_token(&self, refresh_token: &str) -> Result<AuthMethod, LlmError> {
let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);
#[derive(Serialize)]
struct RefreshRequest<'a> {
client_id: &'a str,
refresh_token: &'a str,
grant_type: &'a str,
}
let request = RefreshRequest {
client_id: &self.client_id,
refresh_token,
grant_type: "refresh_token",
};
let response = self
.http
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| LlmError::Http(e.to_string()))?;
if !response.status().is_success() {
let text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(LlmError::Auth(format!("Token refresh failed: {}", text)));
}
let token_response: TokenApiResponse = response
.json()
.await
.map_err(|e| LlmError::Json(e.to_string()))?;
let expires_at = token_response.expires_in.map(|secs| {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs() + secs)
.unwrap_or(0)
});
Ok(AuthMethod::OAuth {
access_token: token_response.access_token,
refresh_token: token_response.refresh_token,
expires_at,
})
}
}
/// Helper to perform the full device auth flow with polling
pub async fn perform_device_auth<F>(
auth: &AnthropicAuth,
on_code: F,
) -> Result<AuthMethod, LlmError>
where
F: FnOnce(&DeviceCodeResponse),
{
// Start the device flow
let device_code = auth.start_device_auth().await?;
// Let caller display the code to user
on_code(&device_code);
// Poll for completion
let poll_interval = std::time::Duration::from_secs(device_code.interval);
let deadline =
std::time::Instant::now() + std::time::Duration::from_secs(device_code.expires_in);
loop {
if std::time::Instant::now() > deadline {
return Err(LlmError::Auth("Device code expired".to_string()));
}
tokio::time::sleep(poll_interval).await;
match auth.poll_device_auth(&device_code.device_code).await? {
DeviceAuthResult::Success {
access_token,
refresh_token,
expires_in,
} => {
let expires_at = expires_in.map(|secs| {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs() + secs)
.unwrap_or(0)
});
return Ok(AuthMethod::OAuth {
access_token,
refresh_token,
expires_at,
});
}
DeviceAuthResult::Pending => continue,
DeviceAuthResult::Denied => {
return Err(LlmError::Auth("Authorization denied by user".to_string()));
}
DeviceAuthResult::Expired => {
return Err(LlmError::Auth("Device code expired".to_string()));
}
}
}
}

View File

@@ -0,0 +1,577 @@
//! Anthropic Claude API Client
//!
//! Implements the Messages API with streaming support.
use crate::types::*;
use async_trait::async_trait;
use futures::StreamExt;
use llm_core::{
AccountInfo, AuthMethod, ChatMessage, ChatOptions, ChatResponse, ChunkStream, FunctionCall,
LlmError, LlmProvider, ModelInfo, ProviderInfo, ProviderStatus, Role, StreamChunk, Tool,
ToolCall, ToolCallDelta, Usage, UsageStats,
};
use reqwest::Client;
use reqwest_eventsource::{Event, EventSource};
use std::sync::Arc;
use tokio::sync::Mutex;
const API_BASE_URL: &str = "https://api.anthropic.com";
const MESSAGES_ENDPOINT: &str = "/v1/messages";
const API_VERSION: &str = "2023-06-01";
const DEFAULT_MAX_TOKENS: u32 = 8192;
/// Anthropic Claude API client
pub struct AnthropicClient {
http: Client,
auth: AuthMethod,
model: String,
}
impl AnthropicClient {
/// Create a new client with API key authentication
pub fn new(api_key: impl Into<String>) -> Self {
Self {
http: Client::new(),
auth: AuthMethod::api_key(api_key),
model: "claude-sonnet-4-20250514".to_string(),
}
}
/// Create a new client with OAuth token
pub fn with_oauth(access_token: impl Into<String>) -> Self {
Self {
http: Client::new(),
auth: AuthMethod::oauth(access_token),
model: "claude-sonnet-4-20250514".to_string(),
}
}
/// Create a new client with full AuthMethod
pub fn with_auth(auth: AuthMethod) -> Self {
Self {
http: Client::new(),
auth,
model: "claude-sonnet-4-20250514".to_string(),
}
}
/// Set the model to use
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
/// Get current auth method (for token refresh)
pub fn auth(&self) -> &AuthMethod {
&self.auth
}
/// Update the auth method (after refresh)
pub fn set_auth(&mut self, auth: AuthMethod) {
self.auth = auth;
}
/// Convert messages to Anthropic format, extracting system message
fn prepare_messages(messages: &[ChatMessage]) -> (Option<String>, Vec<AnthropicMessage>) {
let mut system_content = None;
let mut anthropic_messages = Vec::new();
for msg in messages {
if msg.role == Role::System {
// Collect system messages
if let Some(content) = &msg.content {
if let Some(existing) = &mut system_content {
*existing = format!("{}\n\n{}", existing, content);
} else {
system_content = Some(content.clone());
}
}
} else {
anthropic_messages.push(AnthropicMessage::from(msg));
}
}
(system_content, anthropic_messages)
}
/// Convert tools to Anthropic format
fn prepare_tools(tools: Option<&[Tool]>) -> Option<Vec<AnthropicTool>> {
tools.map(|t| t.iter().map(AnthropicTool::from).collect())
}
}
#[async_trait]
impl LlmProvider for AnthropicClient {
fn name(&self) -> &str {
"anthropic"
}
fn model(&self) -> &str {
&self.model
}
async fn chat_stream(
&self,
messages: &[ChatMessage],
options: &ChatOptions,
tools: Option<&[Tool]>,
) -> Result<ChunkStream, LlmError> {
let url = format!("{}{}", API_BASE_URL, MESSAGES_ENDPOINT);
let model = if options.model.is_empty() {
&self.model
} else {
&options.model
};
let (system, anthropic_messages) = Self::prepare_messages(messages);
let anthropic_tools = Self::prepare_tools(tools);
let request = MessagesRequest {
model,
messages: anthropic_messages,
max_tokens: options.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
system: system.as_deref(),
temperature: options.temperature,
top_p: options.top_p,
stop_sequences: options.stop.as_deref(),
tools: anthropic_tools,
stream: true,
};
let bearer = self
.auth
.bearer_token()
.ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;
// Build the SSE request
let req = self
.http
.post(&url)
.header("x-api-key", bearer)
.header("anthropic-version", API_VERSION)
.header("content-type", "application/json")
.json(&request);
let es = EventSource::new(req).map_err(|e| LlmError::Http(e.to_string()))?;
// State for accumulating tool calls across deltas
let tool_state: Arc<Mutex<Vec<PartialToolCall>>> = Arc::new(Mutex::new(Vec::new()));
let stream = es.filter_map(move |event| {
let tool_state = Arc::clone(&tool_state);
async move {
match event {
Ok(Event::Open) => None,
Ok(Event::Message(msg)) => {
// Parse the SSE data as JSON
let event: StreamEvent = match serde_json::from_str(&msg.data) {
Ok(e) => e,
Err(e) => {
tracing::warn!("Failed to parse SSE event: {}", e);
return None;
}
};
convert_stream_event(event, &tool_state).await
}
Err(reqwest_eventsource::Error::StreamEnded) => None,
Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
}
}
});
Ok(Box::pin(stream))
}
async fn chat(
&self,
messages: &[ChatMessage],
options: &ChatOptions,
tools: Option<&[Tool]>,
) -> Result<ChatResponse, LlmError> {
let url = format!("{}{}", API_BASE_URL, MESSAGES_ENDPOINT);
let model = if options.model.is_empty() {
&self.model
} else {
&options.model
};
let (system, anthropic_messages) = Self::prepare_messages(messages);
let anthropic_tools = Self::prepare_tools(tools);
let request = MessagesRequest {
model,
messages: anthropic_messages,
max_tokens: options.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
system: system.as_deref(),
temperature: options.temperature,
top_p: options.top_p,
stop_sequences: options.stop.as_deref(),
tools: anthropic_tools,
stream: false,
};
let bearer = self
.auth
.bearer_token()
.ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;
let response = self
.http
.post(&url)
.header("x-api-key", bearer)
.header("anthropic-version", API_VERSION)
.json(&request)
.send()
.await
.map_err(|e| LlmError::Http(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
// Check for rate limiting
if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
return Err(LlmError::RateLimit {
retry_after_secs: None,
});
}
return Err(LlmError::Api {
message: text,
code: Some(status.to_string()),
});
}
let api_response: MessagesResponse = response
.json()
.await
.map_err(|e| LlmError::Json(e.to_string()))?;
// Convert response to common format
let mut content = String::new();
let mut tool_calls = Vec::new();
for block in api_response.content {
match block {
ResponseContentBlock::Text { text } => {
content.push_str(&text);
}
ResponseContentBlock::ToolUse { id, name, input } => {
tool_calls.push(ToolCall {
id,
call_type: "function".to_string(),
function: FunctionCall {
name,
arguments: input,
},
});
}
}
}
let usage = api_response.usage.map(|u| Usage {
prompt_tokens: u.input_tokens,
completion_tokens: u.output_tokens,
total_tokens: u.input_tokens + u.output_tokens,
});
Ok(ChatResponse {
content: if content.is_empty() {
None
} else {
Some(content)
},
tool_calls: if tool_calls.is_empty() {
None
} else {
Some(tool_calls)
},
usage,
})
}
}
/// Helper struct for accumulating streaming tool calls
#[derive(Default)]
struct PartialToolCall {
#[allow(dead_code)]
id: String,
#[allow(dead_code)]
name: String,
input_json: String,
}
/// Convert an Anthropic stream event to our common StreamChunk format
async fn convert_stream_event(
event: StreamEvent,
tool_state: &Arc<Mutex<Vec<PartialToolCall>>>,
) -> Option<Result<StreamChunk, LlmError>> {
match event {
StreamEvent::ContentBlockStart {
index,
content_block,
} => {
match content_block {
ContentBlockStartInfo::Text { text } => {
if text.is_empty() {
None
} else {
Some(Ok(StreamChunk {
content: Some(text),
tool_calls: None,
done: false,
usage: None,
}))
}
}
ContentBlockStartInfo::ToolUse { id, name } => {
// Store the tool call start
let mut state = tool_state.lock().await;
while state.len() <= index {
state.push(PartialToolCall::default());
}
state[index] = PartialToolCall {
id: id.clone(),
name: name.clone(),
input_json: String::new(),
};
Some(Ok(StreamChunk {
content: None,
tool_calls: Some(vec![ToolCallDelta {
index,
id: Some(id),
function_name: Some(name),
arguments_delta: None,
}]),
done: false,
usage: None,
}))
}
}
}
StreamEvent::ContentBlockDelta { index, delta } => match delta {
ContentDelta::TextDelta { text } => Some(Ok(StreamChunk {
content: Some(text),
tool_calls: None,
done: false,
usage: None,
})),
ContentDelta::InputJsonDelta { partial_json } => {
// Accumulate the JSON
let mut state = tool_state.lock().await;
if index < state.len() {
state[index].input_json.push_str(&partial_json);
}
Some(Ok(StreamChunk {
content: None,
tool_calls: Some(vec![ToolCallDelta {
index,
id: None,
function_name: None,
arguments_delta: Some(partial_json),
}]),
done: false,
usage: None,
}))
}
},
StreamEvent::MessageDelta { usage, .. } => {
let u = usage.map(|u| Usage {
prompt_tokens: u.input_tokens,
completion_tokens: u.output_tokens,
total_tokens: u.input_tokens + u.output_tokens,
});
Some(Ok(StreamChunk {
content: None,
tool_calls: None,
done: false,
usage: u,
}))
}
StreamEvent::MessageStop => Some(Ok(StreamChunk {
content: None,
tool_calls: None,
done: true,
usage: None,
})),
StreamEvent::Error { error } => Some(Err(LlmError::Api {
message: error.message,
code: Some(error.error_type),
})),
// Ignore other events
StreamEvent::MessageStart { .. }
| StreamEvent::ContentBlockStop { .. }
| StreamEvent::Ping => None,
}
}
// ============================================================================
// ProviderInfo Implementation
// ============================================================================
/// Known Claude models with their specifications
fn get_claude_models() -> Vec<ModelInfo> {
vec![
ModelInfo {
id: "claude-opus-4-20250514".to_string(),
display_name: Some("Claude Opus 4".to_string()),
description: Some("Most capable model for complex tasks".to_string()),
context_window: Some(200_000),
max_output_tokens: Some(32_000),
supports_tools: true,
supports_vision: true,
input_price_per_mtok: Some(15.0),
output_price_per_mtok: Some(75.0),
},
ModelInfo {
id: "claude-sonnet-4-20250514".to_string(),
display_name: Some("Claude Sonnet 4".to_string()),
description: Some("Best balance of performance and speed".to_string()),
context_window: Some(200_000),
max_output_tokens: Some(64_000),
supports_tools: true,
supports_vision: true,
input_price_per_mtok: Some(3.0),
output_price_per_mtok: Some(15.0),
},
ModelInfo {
id: "claude-haiku-3-5-20241022".to_string(),
display_name: Some("Claude 3.5 Haiku".to_string()),
description: Some("Fast and affordable for simple tasks".to_string()),
context_window: Some(200_000),
max_output_tokens: Some(8_192),
supports_tools: true,
supports_vision: true,
input_price_per_mtok: Some(0.80),
output_price_per_mtok: Some(4.0),
},
]
}
#[async_trait]
impl ProviderInfo for AnthropicClient {
async fn status(&self) -> Result<ProviderStatus, LlmError> {
let authenticated = self.auth.bearer_token().is_some();
// Try to reach the API with a simple request
let reachable = if authenticated {
// Test with a minimal message to verify auth works
let test_messages = vec![ChatMessage::user("Hi")];
let test_opts = ChatOptions::new(&self.model).with_max_tokens(1);
match self.chat(&test_messages, &test_opts, None).await {
Ok(_) => true,
Err(LlmError::Auth(_)) => false, // Auth failed
Err(_) => true, // Other errors mean API is reachable
}
} else {
false
};
let account = if authenticated && reachable {
self.account_info().await.ok().flatten()
} else {
None
};
let message = if !authenticated {
Some("Not authenticated - run 'owlen login anthropic' to authenticate".to_string())
} else if !reachable {
Some("Cannot reach Anthropic API".to_string())
} else {
Some("Connected".to_string())
};
Ok(ProviderStatus {
provider: "anthropic".to_string(),
authenticated,
account,
model: self.model.clone(),
endpoint: API_BASE_URL.to_string(),
reachable,
message,
})
}
async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError> {
// Anthropic doesn't have a public account info endpoint
// Return None - account info would come from OAuth token claims
Ok(None)
}
async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError> {
// Anthropic doesn't expose usage stats via API
// This would require the admin/billing API with different auth
Ok(None)
}
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
// Return known models - Anthropic doesn't have a models list endpoint
Ok(get_claude_models())
}
async fn model_info(&self, model_id: &str) -> Result<Option<ModelInfo>, LlmError> {
let models = get_claude_models();
Ok(models.into_iter().find(|m| m.id == model_id))
}
}
#[cfg(test)]
mod tests {
use super::*;
use llm_core::ToolParameters;
use serde_json::json;
#[test]
fn test_message_conversion() {
let messages = vec![
ChatMessage::system("You are helpful"),
ChatMessage::user("Hello"),
ChatMessage::assistant("Hi there!"),
];
let (system, anthropic_msgs) = AnthropicClient::prepare_messages(&messages);
assert_eq!(system, Some("You are helpful".to_string()));
assert_eq!(anthropic_msgs.len(), 2);
assert_eq!(anthropic_msgs[0].role, "user");
assert_eq!(anthropic_msgs[1].role, "assistant");
}
#[test]
fn test_tool_conversion() {
let tools = vec![Tool::function(
"read_file",
"Read a file's contents",
ToolParameters::object(
json!({
"path": {
"type": "string",
"description": "File path"
}
}),
vec!["path".to_string()],
),
)];
let anthropic_tools = AnthropicClient::prepare_tools(Some(&tools)).unwrap();
assert_eq!(anthropic_tools.len(), 1);
assert_eq!(anthropic_tools[0].name, "read_file");
assert_eq!(anthropic_tools[0].description, "Read a file's contents");
}
}

View File

@@ -0,0 +1,12 @@
//! Anthropic Claude API Client
//!
//! Implements the LlmProvider trait for Anthropic's Claude models.
//! Supports both API key authentication and OAuth device flow.
mod auth;
mod client;
mod types;
pub use auth::*;
pub use client::*;
pub use types::*;

View File

@@ -0,0 +1,276 @@
//! Anthropic API request/response types
use serde::{Deserialize, Serialize};
use serde_json::Value;
// ============================================================================
// Request Types
// ============================================================================
#[derive(Debug, Serialize)]
pub struct MessagesRequest<'a> {
pub model: &'a str,
pub messages: Vec<AnthropicMessage>,
pub max_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub system: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop_sequences: Option<&'a [String]>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<AnthropicTool>>,
pub stream: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnthropicMessage {
pub role: String, // "user" or "assistant"
pub content: AnthropicContent,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum AnthropicContent {
Text(String),
Blocks(Vec<ContentBlock>),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContentBlock {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "tool_use")]
ToolUse {
id: String,
name: String,
input: Value,
},
#[serde(rename = "tool_result")]
ToolResult {
tool_use_id: String,
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
is_error: Option<bool>,
},
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AnthropicTool {
pub name: String,
pub description: String,
pub input_schema: ToolInputSchema,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolInputSchema {
#[serde(rename = "type")]
pub schema_type: String,
pub properties: Value,
pub required: Vec<String>,
}
// ============================================================================
// Response Types
// ============================================================================
#[derive(Debug, Clone, Deserialize)]
pub struct MessagesResponse {
pub id: String,
#[serde(rename = "type")]
pub response_type: String,
pub role: String,
pub content: Vec<ResponseContentBlock>,
pub model: String,
pub stop_reason: Option<String>,
pub usage: Option<UsageInfo>,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum ResponseContentBlock {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "tool_use")]
ToolUse {
id: String,
name: String,
input: Value,
},
}
#[derive(Debug, Clone, Deserialize)]
pub struct UsageInfo {
pub input_tokens: u32,
pub output_tokens: u32,
}
// ============================================================================
// Streaming Event Types
// ============================================================================
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum StreamEvent {
#[serde(rename = "message_start")]
MessageStart { message: MessageStartInfo },
#[serde(rename = "content_block_start")]
ContentBlockStart {
index: usize,
content_block: ContentBlockStartInfo,
},
#[serde(rename = "content_block_delta")]
ContentBlockDelta { index: usize, delta: ContentDelta },
#[serde(rename = "content_block_stop")]
ContentBlockStop { index: usize },
#[serde(rename = "message_delta")]
MessageDelta {
delta: MessageDeltaInfo,
usage: Option<UsageInfo>,
},
#[serde(rename = "message_stop")]
MessageStop,
#[serde(rename = "ping")]
Ping,
#[serde(rename = "error")]
Error { error: ApiError },
}
#[derive(Debug, Clone, Deserialize)]
pub struct MessageStartInfo {
pub id: String,
#[serde(rename = "type")]
pub message_type: String,
pub role: String,
pub model: String,
pub usage: Option<UsageInfo>,
}
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum ContentBlockStartInfo {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "tool_use")]
ToolUse { id: String, name: String },
}
#[derive(Debug, Clone, Deserialize)]
#[serde(tag = "type")]
pub enum ContentDelta {
#[serde(rename = "text_delta")]
TextDelta { text: String },
#[serde(rename = "input_json_delta")]
InputJsonDelta { partial_json: String },
}
#[derive(Debug, Clone, Deserialize)]
pub struct MessageDeltaInfo {
pub stop_reason: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ApiError {
#[serde(rename = "type")]
pub error_type: String,
pub message: String,
}
// ============================================================================
// Conversions
// ============================================================================
impl From<&llm_core::Tool> for AnthropicTool {
fn from(tool: &llm_core::Tool) -> Self {
Self {
name: tool.function.name.clone(),
description: tool.function.description.clone(),
input_schema: ToolInputSchema {
schema_type: tool.function.parameters.param_type.clone(),
properties: tool.function.parameters.properties.clone(),
required: tool.function.parameters.required.clone(),
},
}
}
}
impl From<&llm_core::ChatMessage> for AnthropicMessage {
fn from(msg: &llm_core::ChatMessage) -> Self {
use llm_core::Role;
let role = match msg.role {
Role::User | Role::System => "user",
Role::Assistant => "assistant",
Role::Tool => "user", // Tool results come as user messages in Anthropic
};
// Handle tool results
if msg.role == Role::Tool {
if let (Some(tool_call_id), Some(content)) = (&msg.tool_call_id, &msg.content) {
return Self {
role: "user".to_string(),
content: AnthropicContent::Blocks(vec![ContentBlock::ToolResult {
tool_use_id: tool_call_id.clone(),
content: content.clone(),
is_error: None,
}]),
};
}
}
// Handle assistant messages with tool calls
if msg.role == Role::Assistant {
if let Some(tool_calls) = &msg.tool_calls {
let mut blocks: Vec<ContentBlock> = Vec::new();
// Add text content if present
if let Some(text) = &msg.content {
if !text.is_empty() {
blocks.push(ContentBlock::Text { text: text.clone() });
}
}
// Add tool use blocks
for call in tool_calls {
blocks.push(ContentBlock::ToolUse {
id: call.id.clone(),
name: call.function.name.clone(),
input: call.function.arguments.clone(),
});
}
return Self {
role: "assistant".to_string(),
content: AnthropicContent::Blocks(blocks),
};
}
}
// Simple text message
Self {
role: role.to_string(),
content: AnthropicContent::Text(msg.content.clone().unwrap_or_default()),
}
}
}

View File

@@ -0,0 +1,18 @@
[package]
name = "llm-core"
version = "0.1.0"
edition.workspace = true
license.workspace = true
description = "LLM provider abstraction layer for Owlen"
[dependencies]
async-trait = "0.1"
futures = "0.3"
rand = "0.8"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "2.0"
tokio = { version = "1.0", features = ["time"] }
[dev-dependencies]
tokio = { version = "1.0", features = ["macros", "rt"] }

View File

@@ -0,0 +1,195 @@
//! Token counting example
//!
//! This example demonstrates how to use the token counting utilities
//! to manage LLM context windows.
//!
//! Run with: cargo run --example token_counting -p llm-core
use llm_core::{
ChatMessage, ClaudeTokenCounter, ContextWindow, SimpleTokenCounter, TokenCounter,
};
fn main() {
println!("=== Token Counting Example ===\n");
// Example 1: Basic token counting with SimpleTokenCounter
println!("1. Basic Token Counting");
println!("{}", "-".repeat(50));
let simple_counter = SimpleTokenCounter::new(8192);
let text = "The quick brown fox jumps over the lazy dog.";
let token_count = simple_counter.count(text);
println!("Text: \"{}\"", text);
println!("Estimated tokens: {}", token_count);
println!("Max context: {}\n", simple_counter.max_context());
// Example 2: Counting tokens in chat messages
println!("2. Counting Tokens in Chat Messages");
println!("{}", "-".repeat(50));
let messages = vec![
ChatMessage::system("You are a helpful assistant that provides concise answers."),
ChatMessage::user("What is the capital of France?"),
ChatMessage::assistant("The capital of France is Paris."),
ChatMessage::user("What is its population?"),
];
let total_tokens = simple_counter.count_messages(&messages);
println!("Number of messages: {}", messages.len());
println!("Total tokens (with overhead): {}\n", total_tokens);
// Example 3: Using ClaudeTokenCounter for Claude models
println!("3. Claude-Specific Token Counting");
println!("{}", "-".repeat(50));
let claude_counter = ClaudeTokenCounter::new();
let claude_total = claude_counter.count_messages(&messages);
println!("Claude counter max context: {}", claude_counter.max_context());
println!("Claude estimated tokens: {}\n", claude_total);
// Example 4: Context window management
println!("4. Context Window Management");
println!("{}", "-".repeat(50));
let mut context = ContextWindow::new(8192);
println!("Created context window with max: {} tokens", context.max());
// Simulate adding messages
let conversation = vec![
ChatMessage::user("Tell me about Rust programming."),
ChatMessage::assistant(
"Rust is a systems programming language focused on safety, \
speed, and concurrency. It prevents common bugs like null pointer \
dereferences and data races through its ownership system.",
),
ChatMessage::user("What are its main features?"),
ChatMessage::assistant(
"Rust's main features include: 1) Memory safety without garbage collection, \
2) Zero-cost abstractions, 3) Fearless concurrency, 4) Pattern matching, \
5) Type inference, and 6) A powerful macro system.",
),
];
for (i, msg) in conversation.iter().enumerate() {
let tokens = simple_counter.count_messages(&[msg.clone()]);
context.add_tokens(tokens);
let role = msg.role.as_str();
let preview = msg
.content
.as_ref()
.map(|c| {
if c.len() > 50 {
format!("{}...", &c[..50])
} else {
c.clone()
}
})
.unwrap_or_default();
println!(
"Message {}: [{}] \"{}\"",
i + 1,
role,
preview
);
println!(" Added {} tokens", tokens);
println!(" Total used: {} / {}", context.used(), context.max());
println!(" Usage: {:.1}%", context.usage_percent() * 100.0);
println!(" Progress: {}\n", context.progress_bar(30));
}
// Example 5: Checking context limits
println!("5. Checking Context Limits");
println!("{}", "-".repeat(50));
if context.is_near_limit(0.8) {
println!("Warning: Context is over 80% full!");
} else {
println!("Context usage is below 80%");
}
let remaining = context.remaining();
println!("Remaining tokens: {}", remaining);
let new_message_tokens = 500;
if context.has_room_for(new_message_tokens) {
println!(
"Can fit a message of {} tokens",
new_message_tokens
);
} else {
println!(
"Cannot fit a message of {} tokens - would need to compact or start new context",
new_message_tokens
);
}
// Example 6: Different counter variants
println!("\n6. Using Different Counter Variants");
println!("{}", "-".repeat(50));
let counter_8k = SimpleTokenCounter::default_8k();
let counter_32k = SimpleTokenCounter::with_32k();
let counter_128k = SimpleTokenCounter::with_128k();
println!("8k context counter: {} tokens", counter_8k.max_context());
println!("32k context counter: {} tokens", counter_32k.max_context());
println!("128k context counter: {} tokens", counter_128k.max_context());
let haiku = ClaudeTokenCounter::haiku();
let sonnet = ClaudeTokenCounter::sonnet();
let opus = ClaudeTokenCounter::opus();
println!("\nClaude Haiku: {} tokens", haiku.max_context());
println!("Claude Sonnet: {} tokens", sonnet.max_context());
println!("Claude Opus: {} tokens", opus.max_context());
// Example 7: Managing context for a long conversation
println!("\n7. Long Conversation Simulation");
println!("{}", "-".repeat(50));
let mut long_context = ContextWindow::new(4096); // Smaller context for demo
let counter = SimpleTokenCounter::new(4096);
let mut message_count = 0;
let mut compaction_count = 0;
// Simulate 20 exchanges
for i in 0..20 {
let user_msg = ChatMessage::user(format!(
"This is user message number {} asking a question.",
i + 1
));
let assistant_msg = ChatMessage::assistant(format!(
"This is assistant response number {} providing a detailed answer with multiple sentences to make it longer.",
i + 1
));
let tokens_needed = counter.count_messages(&[user_msg, assistant_msg]);
if !long_context.has_room_for(tokens_needed) {
println!(
"After {} messages, context is full ({}%). Compacting...",
message_count,
(long_context.usage_percent() * 100.0) as u32
);
// In a real scenario, we would compact the conversation
// For now, just reset
long_context.reset();
compaction_count += 1;
}
long_context.add_tokens(tokens_needed);
message_count += 2;
}
println!("Total messages: {}", message_count);
println!("Compactions needed: {}", compaction_count);
println!("Final context usage: {:.1}%", long_context.usage_percent() * 100.0);
println!("Final progress: {}", long_context.progress_bar(40));
println!("\n=== Example Complete ===");
}

796
crates/llm/core/src/lib.rs Normal file
View File

@@ -0,0 +1,796 @@
//! LLM Provider Abstraction Layer
//!
//! This crate defines the common types and traits for LLM provider integration.
//! Providers (Ollama, Anthropic Claude, OpenAI) implement the `LlmProvider` trait
//! to enable swapping providers at runtime.
use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::pin::Pin;
use thiserror::Error;
// ============================================================================
// Public Modules
// ============================================================================
pub mod retry;
pub mod tokens;
// Re-export token counting types for convenience
pub use tokens::{ClaudeTokenCounter, ContextWindow, SimpleTokenCounter, TokenCounter};
// Re-export retry types for convenience
pub use retry::{is_retryable_error, RetryConfig, RetryStrategy};
// ============================================================================
// Error Types
// ============================================================================
#[derive(Error, Debug)]
pub enum LlmError {
#[error("HTTP error: {0}")]
Http(String),
#[error("JSON parsing error: {0}")]
Json(String),
#[error("Authentication error: {0}")]
Auth(String),
#[error("Rate limit exceeded: retry after {retry_after_secs:?} seconds")]
RateLimit { retry_after_secs: Option<u64> },
#[error("API error: {message}")]
Api { message: String, code: Option<String> },
#[error("Provider error: {0}")]
Provider(String),
#[error("Stream error: {0}")]
Stream(String),
#[error("Request timeout: {0}")]
Timeout(String),
}
// ============================================================================
// Message Types
// ============================================================================
/// Role of a message in the conversation
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
System,
User,
Assistant,
Tool,
}
impl Role {
pub fn as_str(&self) -> &'static str {
match self {
Role::System => "system",
Role::User => "user",
Role::Assistant => "assistant",
Role::Tool => "tool",
}
}
}
impl From<&str> for Role {
fn from(s: &str) -> Self {
match s.to_lowercase().as_str() {
"system" => Role::System,
"user" => Role::User,
"assistant" => Role::Assistant,
"tool" => Role::Tool,
_ => Role::User, // Default fallback
}
}
}
/// A message in the conversation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: Role,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
/// Tool calls made by the assistant
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
/// For tool role messages: the ID of the tool call this responds to
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
/// For tool role messages: the name of the tool
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}
impl ChatMessage {
/// Create a system message
pub fn system(content: impl Into<String>) -> Self {
Self {
role: Role::System,
content: Some(content.into()),
tool_calls: None,
tool_call_id: None,
name: None,
}
}
/// Create a user message
pub fn user(content: impl Into<String>) -> Self {
Self {
role: Role::User,
content: Some(content.into()),
tool_calls: None,
tool_call_id: None,
name: None,
}
}
/// Create an assistant message
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: Role::Assistant,
content: Some(content.into()),
tool_calls: None,
tool_call_id: None,
name: None,
}
}
/// Create an assistant message with tool calls (no text content)
pub fn assistant_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
Self {
role: Role::Assistant,
content: None,
tool_calls: Some(tool_calls),
tool_call_id: None,
name: None,
}
}
/// Create a tool result message
pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
Self {
role: Role::Tool,
content: Some(content.into()),
tool_calls: None,
tool_call_id: Some(tool_call_id.into()),
name: None,
}
}
}
// ============================================================================
// Tool Types
// ============================================================================
/// A tool call requested by the LLM
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
/// Unique identifier for this tool call
pub id: String,
/// The type of tool call (always "function" for now)
#[serde(rename = "type", default = "default_function_type")]
pub call_type: String,
/// The function being called
pub function: FunctionCall,
}
fn default_function_type() -> String {
"function".to_string()
}
/// Details of a function call
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCall {
/// Name of the function to call
pub name: String,
/// Arguments as a JSON object
pub arguments: Value,
}
/// Definition of a tool available to the LLM
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
#[serde(rename = "type")]
pub tool_type: String,
pub function: ToolFunction,
}
impl Tool {
/// Create a new function tool
pub fn function(
name: impl Into<String>,
description: impl Into<String>,
parameters: ToolParameters,
) -> Self {
Self {
tool_type: "function".to_string(),
function: ToolFunction {
name: name.into(),
description: description.into(),
parameters,
},
}
}
}
/// Function definition within a tool
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunction {
pub name: String,
pub description: String,
pub parameters: ToolParameters,
}
/// Parameters schema for a function
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameters {
#[serde(rename = "type")]
pub param_type: String,
/// JSON Schema properties object
pub properties: Value,
/// Required parameter names
pub required: Vec<String>,
}
impl ToolParameters {
/// Create an object parameter schema
pub fn object(properties: Value, required: Vec<String>) -> Self {
Self {
param_type: "object".to_string(),
properties,
required,
}
}
}
// ============================================================================
// Streaming Response Types
// ============================================================================
/// A chunk of a streaming response
#[derive(Debug, Clone)]
pub struct StreamChunk {
/// Incremental text content
pub content: Option<String>,
/// Tool calls (may be partial/streaming)
pub tool_calls: Option<Vec<ToolCallDelta>>,
/// Whether this is the final chunk
pub done: bool,
/// Usage statistics (typically only in final chunk)
pub usage: Option<Usage>,
}
/// Partial tool call for streaming
#[derive(Debug, Clone)]
pub struct ToolCallDelta {
/// Index of this tool call in the array
pub index: usize,
/// Tool call ID (may only be present in first delta)
pub id: Option<String>,
/// Function name (may only be present in first delta)
pub function_name: Option<String>,
/// Incremental arguments string
pub arguments_delta: Option<String>,
}
/// Token usage statistics
#[derive(Debug, Clone, Default)]
pub struct Usage {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
// ============================================================================
// Provider Configuration
// ============================================================================
/// Options for a chat request
#[derive(Debug, Clone, Default)]
pub struct ChatOptions {
/// Model to use
pub model: String,
/// Temperature (0.0 - 2.0)
pub temperature: Option<f32>,
/// Maximum tokens to generate
pub max_tokens: Option<u32>,
/// Top-p sampling
pub top_p: Option<f32>,
/// Stop sequences
pub stop: Option<Vec<String>>,
}
impl ChatOptions {
pub fn new(model: impl Into<String>) -> Self {
Self {
model: model.into(),
..Default::default()
}
}
pub fn with_temperature(mut self, temp: f32) -> Self {
self.temperature = Some(temp);
self
}
pub fn with_max_tokens(mut self, max: u32) -> Self {
self.max_tokens = Some(max);
self
}
}
// ============================================================================
// Provider Trait
// ============================================================================
/// A boxed stream of chunks
pub type ChunkStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, LlmError>> + Send>>;
/// The main trait that all LLM providers must implement
#[async_trait]
pub trait LlmProvider: Send + Sync {
/// Get the provider name (e.g., "ollama", "anthropic", "openai")
fn name(&self) -> &str;
/// Get the current model name
fn model(&self) -> &str;
/// Send a chat request and receive a streaming response
///
/// # Arguments
/// * `messages` - The conversation history
/// * `options` - Request options (model, temperature, etc.)
/// * `tools` - Optional list of tools the model can use
///
/// # Returns
/// A stream of response chunks
async fn chat_stream(
&self,
messages: &[ChatMessage],
options: &ChatOptions,
tools: Option<&[Tool]>,
) -> Result<ChunkStream, LlmError>;
/// Send a chat request and receive a complete response (non-streaming)
///
/// Default implementation collects the stream, but providers may override
/// for efficiency.
async fn chat(
&self,
messages: &[ChatMessage],
options: &ChatOptions,
tools: Option<&[Tool]>,
) -> Result<ChatResponse, LlmError> {
use futures::StreamExt;
let mut stream = self.chat_stream(messages, options, tools).await?;
let mut content = String::new();
let mut tool_calls: Vec<PartialToolCall> = Vec::new();
let mut usage = None;
while let Some(chunk) = stream.next().await {
let chunk = chunk?;
if let Some(text) = chunk.content {
content.push_str(&text);
}
if let Some(deltas) = chunk.tool_calls {
for delta in deltas {
// Grow the tool_calls vec if needed
while tool_calls.len() <= delta.index {
tool_calls.push(PartialToolCall::default());
}
let partial = &mut tool_calls[delta.index];
if let Some(id) = delta.id {
partial.id = Some(id);
}
if let Some(name) = delta.function_name {
partial.function_name = Some(name);
}
if let Some(args) = delta.arguments_delta {
partial.arguments.push_str(&args);
}
}
}
if chunk.usage.is_some() {
usage = chunk.usage;
}
}
// Convert partial tool calls to complete tool calls
let final_tool_calls: Vec<ToolCall> = tool_calls
.into_iter()
.filter_map(|p| p.try_into_tool_call())
.collect();
Ok(ChatResponse {
content: if content.is_empty() {
None
} else {
Some(content)
},
tool_calls: if final_tool_calls.is_empty() {
None
} else {
Some(final_tool_calls)
},
usage,
})
}
}
/// A complete chat response (non-streaming)
#[derive(Debug, Clone)]
pub struct ChatResponse {
pub content: Option<String>,
pub tool_calls: Option<Vec<ToolCall>>,
pub usage: Option<Usage>,
}
/// Helper for accumulating streaming tool calls
#[derive(Default)]
struct PartialToolCall {
id: Option<String>,
function_name: Option<String>,
arguments: String,
}
impl PartialToolCall {
fn try_into_tool_call(self) -> Option<ToolCall> {
let id = self.id?;
let name = self.function_name?;
let arguments: Value = serde_json::from_str(&self.arguments).ok()?;
Some(ToolCall {
id,
call_type: "function".to_string(),
function: FunctionCall { name, arguments },
})
}
}
// ============================================================================
// Authentication
// ============================================================================
/// Authentication method for LLM providers
#[derive(Debug, Clone)]
pub enum AuthMethod {
/// No authentication (for local providers like Ollama)
None,
/// API key authentication
ApiKey(String),
/// OAuth access token (from login flow)
OAuth {
access_token: String,
refresh_token: Option<String>,
expires_at: Option<u64>,
},
}
impl AuthMethod {
/// Create API key auth
pub fn api_key(key: impl Into<String>) -> Self {
Self::ApiKey(key.into())
}
/// Create OAuth auth from tokens
pub fn oauth(access_token: impl Into<String>) -> Self {
Self::OAuth {
access_token: access_token.into(),
refresh_token: None,
expires_at: None,
}
}
/// Create OAuth auth with refresh token
pub fn oauth_with_refresh(
access_token: impl Into<String>,
refresh_token: impl Into<String>,
expires_at: Option<u64>,
) -> Self {
Self::OAuth {
access_token: access_token.into(),
refresh_token: Some(refresh_token.into()),
expires_at,
}
}
/// Get the bearer token for Authorization header
pub fn bearer_token(&self) -> Option<&str> {
match self {
Self::None => None,
Self::ApiKey(key) => Some(key),
Self::OAuth { access_token, .. } => Some(access_token),
}
}
/// Check if token might need refresh
pub fn needs_refresh(&self) -> bool {
match self {
Self::OAuth {
expires_at: Some(exp),
refresh_token: Some(_),
..
} => {
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
// Refresh if expiring within 5 minutes
*exp < now + 300
}
_ => false,
}
}
}
/// Device code response for OAuth device flow
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceCodeResponse {
/// Code the user enters on the verification page
pub user_code: String,
/// URL the user visits to authorize
pub verification_uri: String,
/// Full URL with code pre-filled (if supported)
pub verification_uri_complete: Option<String>,
/// Device code for polling (internal use)
pub device_code: String,
/// How often to poll (in seconds)
pub interval: u64,
/// When the codes expire (in seconds)
pub expires_in: u64,
}
/// Result of polling for device authorization
#[derive(Debug, Clone)]
pub enum DeviceAuthResult {
/// Still waiting for user to authorize
Pending,
/// User authorized, here are the tokens
Success {
access_token: String,
refresh_token: Option<String>,
expires_in: Option<u64>,
},
/// User denied authorization
Denied,
/// Code expired
Expired,
}
/// Trait for providers that support OAuth device flow
#[async_trait]
pub trait OAuthProvider {
/// Start the device authorization flow
async fn start_device_auth(&self) -> Result<DeviceCodeResponse, LlmError>;
/// Poll for the authorization result
async fn poll_device_auth(&self, device_code: &str) -> Result<DeviceAuthResult, LlmError>;
/// Refresh an access token using a refresh token
async fn refresh_token(&self, refresh_token: &str) -> Result<AuthMethod, LlmError>;
}
/// Stored credentials for a provider
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredCredentials {
pub provider: String,
pub access_token: String,
pub refresh_token: Option<String>,
pub expires_at: Option<u64>,
}
// ============================================================================
// Provider Status & Info
// ============================================================================
/// Status information for a provider connection
#[derive(Debug, Clone)]
pub struct ProviderStatus {
/// Provider name
pub provider: String,
/// Whether the connection is authenticated
pub authenticated: bool,
/// Current user/account info if authenticated
pub account: Option<AccountInfo>,
/// Current model being used
pub model: String,
/// API endpoint URL
pub endpoint: String,
/// Whether the provider is reachable
pub reachable: bool,
/// Any status message or error
pub message: Option<String>,
}
/// Account/user information from the provider
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccountInfo {
/// Account/user ID
pub id: Option<String>,
/// Display name or email
pub name: Option<String>,
/// Account email
pub email: Option<String>,
/// Account type (free, pro, team, enterprise)
pub account_type: Option<String>,
/// Organization name if applicable
pub organization: Option<String>,
}
/// Usage statistics from the provider
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageStats {
/// Total tokens used in current period
pub tokens_used: Option<u64>,
/// Token limit for current period (if applicable)
pub token_limit: Option<u64>,
/// Number of requests made
pub requests_made: Option<u64>,
/// Request limit (if applicable)
pub request_limit: Option<u64>,
/// Cost incurred (if available)
pub cost_usd: Option<f64>,
/// Period start timestamp
pub period_start: Option<u64>,
/// Period end timestamp
pub period_end: Option<u64>,
}
/// Available model information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
/// Model ID/name
pub id: String,
/// Human-readable display name
pub display_name: Option<String>,
/// Model description
pub description: Option<String>,
/// Context window size (tokens)
pub context_window: Option<u32>,
/// Max output tokens
pub max_output_tokens: Option<u32>,
/// Whether the model supports tool use
pub supports_tools: bool,
/// Whether the model supports vision/images
pub supports_vision: bool,
/// Input token price per 1M tokens (USD)
pub input_price_per_mtok: Option<f64>,
/// Output token price per 1M tokens (USD)
pub output_price_per_mtok: Option<f64>,
}
/// Trait for providers that support status/info queries
#[async_trait]
pub trait ProviderInfo {
/// Get the current connection status
async fn status(&self) -> Result<ProviderStatus, LlmError>;
/// Get account information (if authenticated)
async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError>;
/// Get usage statistics (if available)
async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError>;
/// List available models
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError>;
/// Check if a specific model is available
async fn model_info(&self, model_id: &str) -> Result<Option<ModelInfo>, LlmError> {
let models = self.list_models().await?;
Ok(models.into_iter().find(|m| m.id == model_id))
}
}
// ============================================================================
// Provider Factory
// ============================================================================
/// Supported LLM providers
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderType {
Ollama,
Anthropic,
OpenAI,
}
impl ProviderType {
pub fn from_str(s: &str) -> Option<Self> {
match s.to_lowercase().as_str() {
"ollama" => Some(Self::Ollama),
"anthropic" | "claude" => Some(Self::Anthropic),
"openai" | "gpt" => Some(Self::OpenAI),
_ => None,
}
}
pub fn as_str(&self) -> &'static str {
match self {
Self::Ollama => "ollama",
Self::Anthropic => "anthropic",
Self::OpenAI => "openai",
}
}
/// Default model for this provider
pub fn default_model(&self) -> &'static str {
match self {
Self::Ollama => "qwen3:8b",
Self::Anthropic => "claude-sonnet-4-20250514",
Self::OpenAI => "gpt-4o",
}
}
}
impl std::fmt::Display for ProviderType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.as_str())
}
}

View File

@@ -0,0 +1,386 @@
//! Error recovery and retry logic for LLM operations
//!
//! This module provides configurable retry strategies with exponential backoff
//! for handling transient failures when communicating with LLM providers.
use crate::LlmError;
use rand::Rng;
use std::time::Duration;
/// Configuration for retry behavior
#[derive(Debug, Clone)]
pub struct RetryConfig {
/// Maximum number of retry attempts
pub max_retries: u32,
/// Initial delay before first retry (in milliseconds)
pub initial_delay_ms: u64,
/// Maximum delay between retries (in milliseconds)
pub max_delay_ms: u64,
/// Multiplier for exponential backoff
pub backoff_multiplier: f32,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
initial_delay_ms: 1000,
max_delay_ms: 30000,
backoff_multiplier: 2.0,
}
}
}
impl RetryConfig {
/// Create a new retry configuration with custom values
pub fn new(
max_retries: u32,
initial_delay_ms: u64,
max_delay_ms: u64,
backoff_multiplier: f32,
) -> Self {
Self {
max_retries,
initial_delay_ms,
max_delay_ms,
backoff_multiplier,
}
}
/// Create a configuration with no retries
pub fn no_retry() -> Self {
Self {
max_retries: 0,
initial_delay_ms: 0,
max_delay_ms: 0,
backoff_multiplier: 1.0,
}
}
/// Create a configuration with aggressive retries for rate-limited scenarios
pub fn aggressive() -> Self {
Self {
max_retries: 5,
initial_delay_ms: 2000,
max_delay_ms: 60000,
backoff_multiplier: 2.5,
}
}
}
/// Determines whether an error is retryable
///
/// # Arguments
/// * `error` - The error to check
///
/// # Returns
/// `true` if the error is transient and the operation should be retried,
/// `false` if the error is permanent and retrying won't help
pub fn is_retryable_error(error: &LlmError) -> bool {
match error {
// Always retry rate limits
LlmError::RateLimit { .. } => true,
// Always retry timeouts
LlmError::Timeout(_) => true,
// Retry HTTP errors that are server-side (5xx)
LlmError::Http(msg) => {
// Check if the error message contains a 5xx status code
msg.contains("500")
|| msg.contains("502")
|| msg.contains("503")
|| msg.contains("504")
|| msg.contains("Internal Server Error")
|| msg.contains("Bad Gateway")
|| msg.contains("Service Unavailable")
|| msg.contains("Gateway Timeout")
}
// Don't retry authentication errors - they need user intervention
LlmError::Auth(_) => false,
// Don't retry JSON parsing errors - the data is malformed
LlmError::Json(_) => false,
// Don't retry API errors - these are typically client-side issues
LlmError::Api { .. } => false,
// Provider errors might be transient, but we conservatively don't retry
LlmError::Provider(_) => false,
// Stream errors are typically not retryable
LlmError::Stream(_) => false,
}
}
/// Strategy for retrying failed operations with exponential backoff
#[derive(Debug, Clone)]
pub struct RetryStrategy {
config: RetryConfig,
}
impl RetryStrategy {
/// Create a new retry strategy with the given configuration
pub fn new(config: RetryConfig) -> Self {
Self { config }
}
/// Create a retry strategy with default configuration
pub fn default_config() -> Self {
Self::new(RetryConfig::default())
}
/// Execute an async operation with retries
///
/// # Arguments
/// * `operation` - A function that returns a Future producing a Result
///
/// # Returns
/// The result of the operation, or the last error if all retries fail
///
/// # Example
/// ```ignore
/// let strategy = RetryStrategy::default_config();
/// let result = strategy.execute(|| async {
/// // Your LLM API call here
/// llm_client.chat(&messages, &options, None).await
/// }).await?;
/// ```
pub async fn execute<F, T, Fut>(&self, operation: F) -> Result<T, LlmError>
where
F: Fn() -> Fut,
Fut: std::future::Future<Output = Result<T, LlmError>>,
{
let mut attempt = 0;
loop {
// Try the operation
match operation().await {
Ok(result) => return Ok(result),
Err(err) => {
// Check if we should retry
if !is_retryable_error(&err) {
return Err(err);
}
attempt += 1;
// Check if we've exhausted retries
if attempt > self.config.max_retries {
return Err(err);
}
// Calculate delay with exponential backoff and jitter
let delay = self.delay_for_attempt(attempt);
// Log retry attempt (in a real implementation, you might use tracing)
eprintln!(
"Retry attempt {}/{} after {:?}",
attempt, self.config.max_retries, delay
);
// Sleep before next attempt
tokio::time::sleep(delay).await;
}
}
}
}
/// Calculate the delay for a given attempt number with jitter
///
/// Uses exponential backoff: delay = initial_delay * (backoff_multiplier ^ (attempt - 1))
/// Adds random jitter of ±10% to prevent thundering herd problems
///
/// # Arguments
/// * `attempt` - The attempt number (1-indexed)
///
/// # Returns
/// The delay duration to wait before the next retry
fn delay_for_attempt(&self, attempt: u32) -> Duration {
// Calculate base delay with exponential backoff
let base_delay_ms = self.config.initial_delay_ms as f64
* self.config.backoff_multiplier.powi((attempt - 1) as i32) as f64;
// Cap at max_delay_ms
let capped_delay_ms = base_delay_ms.min(self.config.max_delay_ms as f64);
// Add jitter: ±10%
let mut rng = rand::thread_rng();
let jitter_factor = rng.gen_range(0.9..=1.1);
let final_delay_ms = capped_delay_ms * jitter_factor;
Duration::from_millis(final_delay_ms as u64)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
#[test]
fn test_default_retry_config() {
let config = RetryConfig::default();
assert_eq!(config.max_retries, 3);
assert_eq!(config.initial_delay_ms, 1000);
assert_eq!(config.max_delay_ms, 30000);
assert_eq!(config.backoff_multiplier, 2.0);
}
#[test]
fn test_no_retry_config() {
let config = RetryConfig::no_retry();
assert_eq!(config.max_retries, 0);
}
#[test]
fn test_is_retryable_error() {
// Retryable errors
assert!(is_retryable_error(&LlmError::RateLimit {
retry_after_secs: Some(60)
}));
assert!(is_retryable_error(&LlmError::Timeout(
"Request timed out".to_string()
)));
assert!(is_retryable_error(&LlmError::Http(
"500 Internal Server Error".to_string()
)));
assert!(is_retryable_error(&LlmError::Http(
"503 Service Unavailable".to_string()
)));
// Non-retryable errors
assert!(!is_retryable_error(&LlmError::Auth(
"Invalid API key".to_string()
)));
assert!(!is_retryable_error(&LlmError::Json(
"Invalid JSON".to_string()
)));
assert!(!is_retryable_error(&LlmError::Api {
message: "Invalid request".to_string(),
code: Some("400".to_string())
}));
assert!(!is_retryable_error(&LlmError::Http(
"400 Bad Request".to_string()
)));
}
#[test]
fn test_delay_calculation() {
let config = RetryConfig::default();
let strategy = RetryStrategy::new(config);
// Test that delays increase exponentially
let delay1 = strategy.delay_for_attempt(1);
let delay2 = strategy.delay_for_attempt(2);
let delay3 = strategy.delay_for_attempt(3);
// Base delays should be around 1000ms, 2000ms, 4000ms (with jitter)
assert!(delay1.as_millis() >= 900 && delay1.as_millis() <= 1100);
assert!(delay2.as_millis() >= 1800 && delay2.as_millis() <= 2200);
assert!(delay3.as_millis() >= 3600 && delay3.as_millis() <= 4400);
}
#[test]
fn test_delay_max_cap() {
let config = RetryConfig {
max_retries: 10,
initial_delay_ms: 1000,
max_delay_ms: 5000,
backoff_multiplier: 2.0,
};
let strategy = RetryStrategy::new(config);
// Even with high attempt numbers, delay should be capped
let delay = strategy.delay_for_attempt(10);
assert!(delay.as_millis() <= 5500); // max + jitter
}
#[tokio::test]
async fn test_retry_success_on_first_attempt() {
let strategy = RetryStrategy::default_config();
let call_count = Arc::new(AtomicU32::new(0));
let count_clone = call_count.clone();
let result = strategy
.execute(|| {
let count = count_clone.clone();
async move {
count.fetch_add(1, Ordering::SeqCst);
Ok::<_, LlmError>(42)
}
})
.await;
assert_eq!(result.unwrap(), 42);
assert_eq!(call_count.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn test_retry_success_after_retries() {
let config = RetryConfig::new(3, 10, 100, 2.0); // Fast retries for testing
let strategy = RetryStrategy::new(config);
let call_count = Arc::new(AtomicU32::new(0));
let count_clone = call_count.clone();
let result = strategy
.execute(|| {
let count = count_clone.clone();
async move {
let current = count.fetch_add(1, Ordering::SeqCst) + 1;
if current < 3 {
Err(LlmError::Timeout("Timeout".to_string()))
} else {
Ok(42)
}
}
})
.await;
assert_eq!(result.unwrap(), 42);
assert_eq!(call_count.load(Ordering::SeqCst), 3);
}
#[tokio::test]
async fn test_retry_exhausted() {
let config = RetryConfig::new(2, 10, 100, 2.0); // Fast retries for testing
let strategy = RetryStrategy::new(config);
let call_count = Arc::new(AtomicU32::new(0));
let count_clone = call_count.clone();
let result = strategy
.execute(|| {
let count = count_clone.clone();
async move {
count.fetch_add(1, Ordering::SeqCst);
Err::<(), _>(LlmError::Timeout("Always fails".to_string()))
}
})
.await;
assert!(result.is_err());
assert_eq!(call_count.load(Ordering::SeqCst), 3); // Initial attempt + 2 retries
}
#[tokio::test]
async fn test_non_retryable_error() {
let strategy = RetryStrategy::default_config();
let call_count = Arc::new(AtomicU32::new(0));
let count_clone = call_count.clone();
let result = strategy
.execute(|| {
let count = count_clone.clone();
async move {
count.fetch_add(1, Ordering::SeqCst);
Err::<(), _>(LlmError::Auth("Invalid API key".to_string()))
}
})
.await;
assert!(result.is_err());
assert_eq!(call_count.load(Ordering::SeqCst), 1); // Should not retry
}
}

View File

@@ -0,0 +1,607 @@
//! Token counting utilities for LLM context management
//!
//! This module provides token counting abstractions and implementations for
//! managing LLM context windows. Token counters estimate token usage without
//! requiring external tokenization libraries, using heuristic-based approaches.
use crate::ChatMessage;
// ============================================================================
// TokenCounter Trait
// ============================================================================
/// Trait for counting tokens in text and chat messages
///
/// Implementations provide model-specific token counting logic to help
/// manage context windows and estimate API costs.
pub trait TokenCounter: Send + Sync {
/// Count tokens in a string
///
/// # Arguments
/// * `text` - The text to count tokens for
///
/// # Returns
/// Estimated number of tokens
fn count(&self, text: &str) -> usize;
/// Count tokens in chat messages
///
/// This accounts for both the message content and the overhead
/// from the chat message structure (roles, delimiters, etc.).
///
/// # Arguments
/// * `messages` - The messages to count tokens for
///
/// # Returns
/// Estimated total tokens including message structure overhead
fn count_messages(&self, messages: &[ChatMessage]) -> usize;
/// Get the model's max context window size
///
/// # Returns
/// Maximum number of tokens the model can handle
fn max_context(&self) -> usize;
}
// ============================================================================
// SimpleTokenCounter
// ============================================================================
/// A basic token counter using simple heuristics
///
/// This counter uses the rule of thumb that English text averages about
/// 4 characters per token. It adds overhead for message structure.
///
/// # Example
/// ```
/// use llm_core::tokens::{TokenCounter, SimpleTokenCounter};
/// use llm_core::ChatMessage;
///
/// let counter = SimpleTokenCounter::new(8192);
/// let text = "Hello, world!";
/// let tokens = counter.count(text);
/// assert!(tokens > 0);
///
/// let messages = vec![
/// ChatMessage::user("What is the weather?"),
/// ChatMessage::assistant("I don't have access to weather data."),
/// ];
/// let total = counter.count_messages(&messages);
/// assert!(total > 0);
/// ```
#[derive(Debug, Clone)]
pub struct SimpleTokenCounter {
max_context: usize,
}
impl SimpleTokenCounter {
/// Create a new simple token counter
///
/// # Arguments
/// * `max_context` - Maximum context window size for the model
pub fn new(max_context: usize) -> Self {
Self { max_context }
}
/// Create a token counter with a default 8192 token context
pub fn default_8k() -> Self {
Self::new(8192)
}
/// Create a token counter with a 32k token context
pub fn with_32k() -> Self {
Self::new(32768)
}
/// Create a token counter with a 128k token context
pub fn with_128k() -> Self {
Self::new(131072)
}
}
impl TokenCounter for SimpleTokenCounter {
fn count(&self, text: &str) -> usize {
// Estimate: approximately 4 characters per token for English
// Add 3 before dividing to round up
(text.len() + 3) / 4
}
fn count_messages(&self, messages: &[ChatMessage]) -> usize {
let mut total = 0;
// Base overhead for message formatting (estimated)
// Each message has role, delimiters, etc.
const MESSAGE_OVERHEAD: usize = 4;
for msg in messages {
// Count role
total += MESSAGE_OVERHEAD;
// Count content
if let Some(content) = &msg.content {
total += self.count(content);
}
// Count tool calls (more expensive due to JSON structure)
if let Some(tool_calls) = &msg.tool_calls {
for tc in tool_calls {
// ID overhead
total += self.count(&tc.id);
// Function name
total += self.count(&tc.function.name);
// Arguments (JSON serialized, add 20% overhead for JSON structure)
let args_str = tc.function.arguments.to_string();
total += (self.count(&args_str) * 12) / 10;
}
}
// Count tool call id for tool result messages
if let Some(tool_call_id) = &msg.tool_call_id {
total += self.count(tool_call_id);
}
// Count tool name for tool result messages
if let Some(name) = &msg.name {
total += self.count(name);
}
}
total
}
fn max_context(&self) -> usize {
self.max_context
}
}
// ============================================================================
// ClaudeTokenCounter
// ============================================================================
/// Token counter optimized for Anthropic Claude models
///
/// Claude models have specific tokenization characteristics and overhead.
/// This counter adjusts the estimates accordingly.
///
/// # Example
/// ```
/// use llm_core::tokens::{TokenCounter, ClaudeTokenCounter};
/// use llm_core::ChatMessage;
///
/// let counter = ClaudeTokenCounter::new();
/// let messages = vec![
/// ChatMessage::system("You are a helpful assistant."),
/// ChatMessage::user("Hello!"),
/// ];
/// let total = counter.count_messages(&messages);
/// ```
#[derive(Debug, Clone)]
pub struct ClaudeTokenCounter {
max_context: usize,
}
impl ClaudeTokenCounter {
/// Create a new Claude token counter with default 200k context
///
/// This is suitable for Claude 3.5 Sonnet, Claude 4 Sonnet, and Claude 4 Opus.
pub fn new() -> Self {
Self {
max_context: 200_000,
}
}
/// Create a Claude counter with a custom context window
///
/// # Arguments
/// * `max_context` - Maximum context window size
pub fn with_context(max_context: usize) -> Self {
Self { max_context }
}
/// Create a counter for Claude 3 Haiku (200k context)
pub fn haiku() -> Self {
Self::new()
}
/// Create a counter for Claude 3.5 Sonnet (200k context)
pub fn sonnet() -> Self {
Self::new()
}
/// Create a counter for Claude 4 Opus (200k context)
pub fn opus() -> Self {
Self::new()
}
}
impl Default for ClaudeTokenCounter {
fn default() -> Self {
Self::new()
}
}
impl TokenCounter for ClaudeTokenCounter {
fn count(&self, text: &str) -> usize {
// Claude's tokenization is similar to the 4 chars/token heuristic
// but tends to be slightly more efficient with structured content
(text.len() + 3) / 4
}
fn count_messages(&self, messages: &[ChatMessage]) -> usize {
let mut total = 0;
// Claude has specific message formatting overhead
const MESSAGE_OVERHEAD: usize = 5;
const SYSTEM_MESSAGE_OVERHEAD: usize = 3;
for msg in messages {
// Different overhead for system vs other messages
let overhead = if matches!(msg.role, crate::Role::System) {
SYSTEM_MESSAGE_OVERHEAD
} else {
MESSAGE_OVERHEAD
};
total += overhead;
// Count content
if let Some(content) = &msg.content {
total += self.count(content);
}
// Count tool calls
if let Some(tool_calls) = &msg.tool_calls {
// Claude's tool call format has additional overhead
const TOOL_CALL_OVERHEAD: usize = 10;
for tc in tool_calls {
total += TOOL_CALL_OVERHEAD;
total += self.count(&tc.id);
total += self.count(&tc.function.name);
// Arguments with JSON structure overhead
let args_str = tc.function.arguments.to_string();
total += (self.count(&args_str) * 12) / 10;
}
}
// Tool result overhead
if msg.tool_call_id.is_some() {
const TOOL_RESULT_OVERHEAD: usize = 8;
total += TOOL_RESULT_OVERHEAD;
if let Some(tool_call_id) = &msg.tool_call_id {
total += self.count(tool_call_id);
}
if let Some(name) = &msg.name {
total += self.count(name);
}
}
}
total
}
fn max_context(&self) -> usize {
self.max_context
}
}
// ============================================================================
// ContextWindow
// ============================================================================
/// Manages context window tracking for a conversation
///
/// Helps monitor token usage and determine when context limits are approaching.
///
/// # Example
/// ```
/// use llm_core::tokens::{ContextWindow, TokenCounter, SimpleTokenCounter};
/// use llm_core::ChatMessage;
///
/// let counter = SimpleTokenCounter::new(8192);
/// let mut window = ContextWindow::new(counter.max_context());
///
/// let messages = vec![
/// ChatMessage::user("Hello!"),
/// ChatMessage::assistant("Hi there!"),
/// ];
///
/// let tokens = counter.count_messages(&messages);
/// window.add_tokens(tokens);
///
/// println!("Used: {} tokens", window.used());
/// println!("Remaining: {} tokens", window.remaining());
/// println!("Usage: {:.1}%", window.usage_percent() * 100.0);
///
/// if window.is_near_limit(0.8) {
/// println!("Warning: Context is 80% full!");
/// }
/// ```
#[derive(Debug, Clone)]
pub struct ContextWindow {
/// Number of tokens currently used
used: usize,
/// Maximum number of tokens allowed
max: usize,
}
impl ContextWindow {
/// Create a new context window tracker
///
/// # Arguments
/// * `max` - Maximum context window size in tokens
pub fn new(max: usize) -> Self {
Self { used: 0, max }
}
/// Create a context window with initial usage
///
/// # Arguments
/// * `max` - Maximum context window size
/// * `used` - Initial number of tokens used
pub fn with_usage(max: usize, used: usize) -> Self {
Self { used, max }
}
/// Get the number of tokens currently used
pub fn used(&self) -> usize {
self.used
}
/// Get the maximum number of tokens
pub fn max(&self) -> usize {
self.max
}
/// Get the number of remaining tokens
pub fn remaining(&self) -> usize {
self.max.saturating_sub(self.used)
}
/// Get the usage as a percentage (0.0 to 1.0)
///
/// Returns the fraction of the context window that is currently used.
pub fn usage_percent(&self) -> f32 {
if self.max == 0 {
return 0.0;
}
self.used as f32 / self.max as f32
}
/// Check if usage is near the limit
///
/// # Arguments
/// * `threshold` - Threshold as a fraction (0.0 to 1.0). For example,
/// 0.8 means "is usage > 80%?"
///
/// # Returns
/// `true` if the current usage exceeds the threshold percentage
pub fn is_near_limit(&self, threshold: f32) -> bool {
self.usage_percent() > threshold
}
/// Add tokens to the usage count
///
/// # Arguments
/// * `tokens` - Number of tokens to add
pub fn add_tokens(&mut self, tokens: usize) {
self.used = self.used.saturating_add(tokens);
}
/// Set the current usage
///
/// # Arguments
/// * `used` - Number of tokens currently used
pub fn set_used(&mut self, used: usize) {
self.used = used;
}
/// Reset the usage counter to zero
pub fn reset(&mut self) {
self.used = 0;
}
/// Check if there's enough room for additional tokens
///
/// # Arguments
/// * `tokens` - Number of tokens needed
///
/// # Returns
/// `true` if adding these tokens would stay within the limit
pub fn has_room_for(&self, tokens: usize) -> bool {
self.used.saturating_add(tokens) <= self.max
}
/// Get a visual progress bar representation
///
/// # Arguments
/// * `width` - Width of the progress bar in characters
///
/// # Returns
/// A string with a simple text-based progress bar
pub fn progress_bar(&self, width: usize) -> String {
if width == 0 {
return String::new();
}
let percent = self.usage_percent();
let filled = ((percent * width as f32) as usize).min(width);
let empty = width - filled;
format!(
"[{}{}] {:.1}%",
"=".repeat(filled),
" ".repeat(empty),
percent * 100.0
)
}
}
// ============================================================================
// Tests
// ============================================================================
#[cfg(test)]
mod tests {
use super::*;
use crate::{ChatMessage, FunctionCall, ToolCall};
use serde_json::json;
#[test]
fn test_simple_counter_basic() {
let counter = SimpleTokenCounter::new(8192);
// Empty string
assert_eq!(counter.count(""), 0);
// Short string (~4 chars/token)
let text = "Hello, world!"; // 13 chars -> ~4 tokens
let count = counter.count(text);
assert!(count >= 3 && count <= 5);
// Longer text
let text = "The quick brown fox jumps over the lazy dog"; // 44 chars -> ~11 tokens
let count = counter.count(text);
assert!(count >= 10 && count <= 13);
}
#[test]
fn test_simple_counter_messages() {
let counter = SimpleTokenCounter::new(8192);
let messages = vec![
ChatMessage::user("Hello!"),
ChatMessage::assistant("Hi there! How can I help you today?"),
];
let total = counter.count_messages(&messages);
// Should be more than just the text due to overhead
let text_only = counter.count("Hello!") + counter.count("Hi there! How can I help you today?");
assert!(total > text_only);
}
#[test]
fn test_simple_counter_with_tool_calls() {
let counter = SimpleTokenCounter::new(8192);
let tool_call = ToolCall {
id: "call_123".to_string(),
call_type: "function".to_string(),
function: FunctionCall {
name: "read_file".to_string(),
arguments: json!({"path": "/etc/hosts"}),
},
};
let messages = vec![ChatMessage::assistant_tool_calls(vec![tool_call])];
let total = counter.count_messages(&messages);
assert!(total > 0);
}
#[test]
fn test_claude_counter() {
let counter = ClaudeTokenCounter::new();
assert_eq!(counter.max_context(), 200_000);
let text = "Hello, Claude!";
let count = counter.count(text);
assert!(count > 0);
}
#[test]
fn test_claude_counter_system_message() {
let counter = ClaudeTokenCounter::new();
let messages = vec![
ChatMessage::system("You are a helpful assistant."),
ChatMessage::user("Hello!"),
];
let total = counter.count_messages(&messages);
assert!(total > 0);
}
#[test]
fn test_context_window() {
let mut window = ContextWindow::new(1000);
assert_eq!(window.used(), 0);
assert_eq!(window.max(), 1000);
assert_eq!(window.remaining(), 1000);
assert_eq!(window.usage_percent(), 0.0);
window.add_tokens(200);
assert_eq!(window.used(), 200);
assert_eq!(window.remaining(), 800);
assert_eq!(window.usage_percent(), 0.2);
window.add_tokens(600);
assert_eq!(window.used(), 800);
assert!(window.is_near_limit(0.7));
assert!(!window.is_near_limit(0.9));
assert!(window.has_room_for(200));
assert!(!window.has_room_for(300));
window.reset();
assert_eq!(window.used(), 0);
}
#[test]
fn test_context_window_progress_bar() {
let mut window = ContextWindow::new(100);
window.add_tokens(50);
let bar = window.progress_bar(10);
assert!(bar.contains("====="));
assert!(bar.contains("50.0%"));
window.add_tokens(40);
let bar = window.progress_bar(10);
assert!(bar.contains("========="));
assert!(bar.contains("90.0%"));
}
#[test]
fn test_context_window_saturation() {
let mut window = ContextWindow::new(100);
// Adding more tokens than max should saturate, not overflow
window.add_tokens(150);
assert_eq!(window.used(), 150);
assert_eq!(window.remaining(), 0);
}
#[test]
fn test_simple_counter_constructors() {
let counter1 = SimpleTokenCounter::default_8k();
assert_eq!(counter1.max_context(), 8192);
let counter2 = SimpleTokenCounter::with_32k();
assert_eq!(counter2.max_context(), 32768);
let counter3 = SimpleTokenCounter::with_128k();
assert_eq!(counter3.max_context(), 131072);
}
#[test]
fn test_claude_counter_variants() {
let haiku = ClaudeTokenCounter::haiku();
assert_eq!(haiku.max_context(), 200_000);
let sonnet = ClaudeTokenCounter::sonnet();
assert_eq!(sonnet.max_context(), 200_000);
let opus = ClaudeTokenCounter::opus();
assert_eq!(opus.max_context(), 200_000);
let custom = ClaudeTokenCounter::with_context(100_000);
assert_eq!(custom.max_context(), 100_000);
}
}

View File

@@ -6,11 +6,13 @@ license.workspace = true
rust-version.workspace = true
[dependencies]
llm-core = { path = "../core" }
reqwest = { version = "0.12", features = ["json", "stream"] }
tokio = { version = "1.39", features = ["rt-multi-thread"] }
tokio = { version = "1.39", features = ["rt-multi-thread", "macros"] }
futures = "0.3"
serde = { version = "1", features = ["derive"] }
serde_json = "1"
thiserror = "1"
bytes = "1"
tokio-stream = "0.1.17"
async-trait = "0.1"

View File

@@ -1,14 +1,20 @@
use crate::types::{ChatMessage, ChatResponseChunk, Tool};
use futures::{Stream, TryStreamExt};
use futures::{Stream, StreamExt, TryStreamExt};
use reqwest::Client;
use serde::Serialize;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use async_trait::async_trait;
use llm_core::{
LlmProvider, ProviderInfo, LlmError, ChatOptions, ChunkStream,
ProviderStatus, AccountInfo, UsageStats, ModelInfo,
};
#[derive(Debug, Clone)]
pub struct OllamaClient {
http: Client,
base_url: String, // e.g. "http://localhost:11434"
api_key: Option<String>, // For Ollama Cloud authentication
current_model: String, // Default model for this client
}
#[derive(Debug, Clone, Default)]
@@ -27,12 +33,24 @@ pub enum OllamaError {
Protocol(String),
}
// Convert OllamaError to LlmError
impl From<OllamaError> for LlmError {
fn from(err: OllamaError) -> Self {
match err {
OllamaError::Http(e) => LlmError::Http(e.to_string()),
OllamaError::Json(e) => LlmError::Json(e.to_string()),
OllamaError::Protocol(msg) => LlmError::Provider(msg),
}
}
}
impl OllamaClient {
pub fn new(base_url: impl Into<String>) -> Self {
Self {
http: Client::new(),
base_url: base_url.into().trim_end_matches('/').to_string(),
api_key: None,
current_model: "qwen3:8b".to_string(),
}
}
@@ -41,12 +59,17 @@ impl OllamaClient {
self
}
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.current_model = model.into();
self
}
pub fn with_cloud() -> Self {
// Same API, different base
Self::new("https://ollama.com")
}
pub async fn chat_stream(
pub async fn chat_stream_raw(
&self,
messages: &[ChatMessage],
opts: &OllamaOptions,
@@ -99,3 +122,208 @@ impl OllamaClient {
Ok(out)
}
}
// ============================================================================
// LlmProvider Trait Implementation
// ============================================================================
#[async_trait]
impl LlmProvider for OllamaClient {
fn name(&self) -> &str {
"ollama"
}
fn model(&self) -> &str {
&self.current_model
}
async fn chat_stream(
&self,
messages: &[llm_core::ChatMessage],
options: &ChatOptions,
tools: Option<&[llm_core::Tool]>,
) -> Result<ChunkStream, LlmError> {
// Convert llm_core messages to Ollama messages
let ollama_messages: Vec<ChatMessage> = messages.iter().map(|m| m.into()).collect();
// Convert llm_core tools to Ollama tools if present
let ollama_tools: Option<Vec<Tool>> = tools.map(|tools| {
tools.iter().map(|t| Tool {
tool_type: t.tool_type.clone(),
function: crate::types::ToolFunction {
name: t.function.name.clone(),
description: t.function.description.clone(),
parameters: crate::types::ToolParameters {
param_type: t.function.parameters.param_type.clone(),
properties: t.function.parameters.properties.clone(),
required: t.function.parameters.required.clone(),
},
},
}).collect()
});
let opts = OllamaOptions {
model: options.model.clone(),
stream: true,
};
// Make the request and build the body inline to avoid lifetime issues
#[derive(Serialize)]
struct Body<'a> {
model: &'a str,
messages: &'a [ChatMessage],
stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<&'a [Tool]>,
}
let url = format!("{}/api/chat", self.base_url);
let body = Body {
model: &opts.model,
messages: &ollama_messages,
stream: true,
tools: ollama_tools.as_deref(),
};
let mut req = self.http.post(url).json(&body);
// Add Authorization header if API key is present
if let Some(ref key) = self.api_key {
req = req.header("Authorization", format!("Bearer {}", key));
}
let resp = req.send().await
.map_err(|e| LlmError::Http(e.to_string()))?;
let bytes_stream = resp.bytes_stream();
// NDJSON parser: split by '\n', parse each as JSON and stream the results
let converted_stream = bytes_stream
.map(|result| {
result.map_err(|e| LlmError::Http(e.to_string()))
})
.map_ok(|bytes| {
// Convert the chunk to a UTF-8 string and own it
let txt = String::from_utf8_lossy(&bytes).into_owned();
// Parse each non-empty line into a ChatResponseChunk
let results: Vec<Result<llm_core::StreamChunk, LlmError>> = txt
.lines()
.filter_map(|line| {
let trimmed = line.trim();
if trimmed.is_empty() {
None
} else {
Some(
serde_json::from_str::<ChatResponseChunk>(trimmed)
.map(|chunk| llm_core::StreamChunk::from(chunk))
.map_err(|e| LlmError::Json(e.to_string())),
)
}
})
.collect();
futures::stream::iter(results)
})
.try_flatten();
Ok(Box::pin(converted_stream))
}
}
// ============================================================================
// ProviderInfo Trait Implementation
// ============================================================================
#[derive(Debug, Clone, Deserialize)]
struct OllamaModelList {
models: Vec<OllamaModel>,
}
#[derive(Debug, Clone, Deserialize)]
struct OllamaModel {
name: String,
#[serde(default)]
modified_at: Option<String>,
#[serde(default)]
size: Option<u64>,
#[serde(default)]
digest: Option<String>,
#[serde(default)]
details: Option<OllamaModelDetails>,
}
#[derive(Debug, Clone, Deserialize)]
struct OllamaModelDetails {
#[serde(default)]
format: Option<String>,
#[serde(default)]
family: Option<String>,
#[serde(default)]
parameter_size: Option<String>,
}
#[async_trait]
impl ProviderInfo for OllamaClient {
async fn status(&self) -> Result<ProviderStatus, LlmError> {
// Try to ping the Ollama server
let url = format!("{}/api/tags", self.base_url);
let reachable = self.http.get(&url).send().await.is_ok();
Ok(ProviderStatus {
provider: "ollama".to_string(),
authenticated: self.api_key.is_some(),
account: None, // Ollama is local, no account info
model: self.current_model.clone(),
endpoint: self.base_url.clone(),
reachable,
message: if reachable {
Some("Connected to Ollama".to_string())
} else {
Some("Cannot reach Ollama server".to_string())
},
})
}
async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError> {
// Ollama is a local service, no account info
Ok(None)
}
async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError> {
// Ollama doesn't track usage statistics
Ok(None)
}
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
let url = format!("{}/api/tags", self.base_url);
let mut req = self.http.get(&url);
// Add Authorization header if API key is present
if let Some(ref key) = self.api_key {
req = req.header("Authorization", format!("Bearer {}", key));
}
let resp = req.send().await
.map_err(|e| LlmError::Http(e.to_string()))?;
let model_list: OllamaModelList = resp.json().await
.map_err(|e| LlmError::Json(e.to_string()))?;
// Convert Ollama models to ModelInfo
let models = model_list.models.into_iter().map(|m| {
ModelInfo {
id: m.name.clone(),
display_name: Some(m.name.clone()),
description: m.details.as_ref()
.and_then(|d| d.family.as_ref())
.map(|f| format!("{} model", f)),
context_window: None, // Ollama doesn't provide this in list
max_output_tokens: None,
supports_tools: true, // Most Ollama models support tools
supports_vision: false, // Would need to check model capabilities
input_price_per_mtok: None, // Local models are free
output_price_per_mtok: None,
}
}).collect();
Ok(models)
}
}

View File

@@ -1,5 +1,13 @@
pub mod client;
pub mod types;
pub use client::{OllamaClient, OllamaOptions};
pub use client::{OllamaClient, OllamaOptions, OllamaError};
pub use types::{ChatMessage, ChatResponseChunk, Tool, ToolCall, ToolFunction, ToolParameters, FunctionCall};
// Re-export llm-core traits and types for convenience
pub use llm_core::{
LlmProvider, ProviderInfo, LlmError,
ChatOptions, StreamChunk, ToolCallDelta, Usage,
ProviderStatus, AccountInfo, UsageStats, ModelInfo,
Role,
};

View File

@@ -1,5 +1,6 @@
use serde::{Deserialize, Serialize};
use serde_json::Value;
use llm_core::{StreamChunk, ToolCallDelta};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
@@ -63,3 +64,67 @@ pub struct ChunkMessage {
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<ToolCall>>,
}
// ============================================================================
// Conversions to/from llm-core types
// ============================================================================
/// Convert from llm_core::ChatMessage to Ollama's ChatMessage
impl From<&llm_core::ChatMessage> for ChatMessage {
fn from(msg: &llm_core::ChatMessage) -> Self {
let role = msg.role.as_str().to_string();
// Convert tool_calls if present
let tool_calls = msg.tool_calls.as_ref().map(|calls| {
calls.iter().map(|tc| ToolCall {
id: Some(tc.id.clone()),
call_type: Some(tc.call_type.clone()),
function: FunctionCall {
name: tc.function.name.clone(),
arguments: tc.function.arguments.clone(),
},
}).collect()
});
ChatMessage {
role,
content: msg.content.clone(),
tool_calls,
}
}
}
/// Convert from Ollama's ChatResponseChunk to llm_core::StreamChunk
impl From<ChatResponseChunk> for StreamChunk {
fn from(chunk: ChatResponseChunk) -> Self {
let done = chunk.done.unwrap_or(false);
let content = chunk.message.as_ref().and_then(|m| m.content.clone());
// Convert tool calls to deltas
let tool_calls = chunk.message.as_ref().and_then(|m| {
m.tool_calls.as_ref().map(|calls| {
calls.iter().enumerate().map(|(index, tc)| {
// Serialize arguments back to JSON string for delta
let arguments_delta = serde_json::to_string(&tc.function.arguments).ok();
ToolCallDelta {
index,
id: tc.id.clone(),
function_name: Some(tc.function.name.clone()),
arguments_delta,
}
}).collect()
})
});
// Ollama doesn't provide per-chunk usage stats, only in final chunk
let usage = None;
StreamChunk {
content,
tool_calls,
done,
usage,
}
}
}

View File

@@ -0,0 +1,18 @@
[package]
name = "llm-openai"
version = "0.1.0"
edition.workspace = true
license.workspace = true
description = "OpenAI GPT API client for Owlen"
[dependencies]
llm-core = { path = "../core" }
async-trait = "0.1"
futures = "0.3"
reqwest = { version = "0.12", features = ["json", "stream"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1", features = ["sync", "time", "io-util"] }
tokio-stream = { version = "0.1", default-features = false, features = ["io-util"] }
tokio-util = { version = "0.7", features = ["codec", "io"] }
tracing = "0.1"

View File

@@ -0,0 +1,285 @@
//! OpenAI OAuth Authentication
//!
//! Implements device code flow for authenticating with OpenAI without API keys.
use llm_core::{AuthMethod, DeviceAuthResult, DeviceCodeResponse, LlmError, OAuthProvider};
use reqwest::Client;
use serde::{Deserialize, Serialize};
/// OAuth client for OpenAI device flow
pub struct OpenAIAuth {
http: Client,
client_id: String,
}
// OpenAI OAuth endpoints
const AUTH_BASE_URL: &str = "https://auth.openai.com";
const DEVICE_CODE_ENDPOINT: &str = "/oauth/device/code";
const TOKEN_ENDPOINT: &str = "/oauth/token";
// Default client ID for Owlen CLI
const DEFAULT_CLIENT_ID: &str = "owlen-cli";
impl OpenAIAuth {
/// Create a new OAuth client with the default CLI client ID
pub fn new() -> Self {
Self {
http: Client::new(),
client_id: DEFAULT_CLIENT_ID.to_string(),
}
}
/// Create with a custom client ID
pub fn with_client_id(client_id: impl Into<String>) -> Self {
Self {
http: Client::new(),
client_id: client_id.into(),
}
}
}
impl Default for OpenAIAuth {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Serialize)]
struct DeviceCodeRequest<'a> {
client_id: &'a str,
scope: &'a str,
}
#[derive(Debug, Deserialize)]
struct DeviceCodeApiResponse {
device_code: String,
user_code: String,
verification_uri: String,
verification_uri_complete: Option<String>,
expires_in: u64,
interval: u64,
}
#[derive(Debug, Serialize)]
struct TokenRequest<'a> {
client_id: &'a str,
device_code: &'a str,
grant_type: &'a str,
}
#[derive(Debug, Deserialize)]
struct TokenApiResponse {
access_token: String,
#[allow(dead_code)]
token_type: String,
expires_in: Option<u64>,
refresh_token: Option<String>,
}
#[derive(Debug, Deserialize)]
struct TokenErrorResponse {
error: String,
error_description: Option<String>,
}
#[async_trait::async_trait]
impl OAuthProvider for OpenAIAuth {
async fn start_device_auth(&self) -> Result<DeviceCodeResponse, LlmError> {
let url = format!("{}{}", AUTH_BASE_URL, DEVICE_CODE_ENDPOINT);
let request = DeviceCodeRequest {
client_id: &self.client_id,
scope: "api.read api.write",
};
let response = self
.http
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| LlmError::Http(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(LlmError::Auth(format!(
"Device code request failed ({}): {}",
status, text
)));
}
let api_response: DeviceCodeApiResponse = response
.json()
.await
.map_err(|e| LlmError::Json(e.to_string()))?;
Ok(DeviceCodeResponse {
device_code: api_response.device_code,
user_code: api_response.user_code,
verification_uri: api_response.verification_uri,
verification_uri_complete: api_response.verification_uri_complete,
expires_in: api_response.expires_in,
interval: api_response.interval,
})
}
async fn poll_device_auth(&self, device_code: &str) -> Result<DeviceAuthResult, LlmError> {
let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);
let request = TokenRequest {
client_id: &self.client_id,
device_code,
grant_type: "urn:ietf:params:oauth:grant-type:device_code",
};
let response = self
.http
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| LlmError::Http(e.to_string()))?;
if response.status().is_success() {
let token_response: TokenApiResponse = response
.json()
.await
.map_err(|e| LlmError::Json(e.to_string()))?;
return Ok(DeviceAuthResult::Success {
access_token: token_response.access_token,
refresh_token: token_response.refresh_token,
expires_in: token_response.expires_in,
});
}
// Parse error response
let error_response: TokenErrorResponse = response
.json()
.await
.map_err(|e| LlmError::Json(e.to_string()))?;
match error_response.error.as_str() {
"authorization_pending" => Ok(DeviceAuthResult::Pending),
"slow_down" => Ok(DeviceAuthResult::Pending),
"access_denied" => Ok(DeviceAuthResult::Denied),
"expired_token" => Ok(DeviceAuthResult::Expired),
_ => Err(LlmError::Auth(format!(
"Token request failed: {} - {}",
error_response.error,
error_response.error_description.unwrap_or_default()
))),
}
}
async fn refresh_token(&self, refresh_token: &str) -> Result<AuthMethod, LlmError> {
let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);
#[derive(Serialize)]
struct RefreshRequest<'a> {
client_id: &'a str,
refresh_token: &'a str,
grant_type: &'a str,
}
let request = RefreshRequest {
client_id: &self.client_id,
refresh_token,
grant_type: "refresh_token",
};
let response = self
.http
.post(&url)
.json(&request)
.send()
.await
.map_err(|e| LlmError::Http(e.to_string()))?;
if !response.status().is_success() {
let text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
return Err(LlmError::Auth(format!("Token refresh failed: {}", text)));
}
let token_response: TokenApiResponse = response
.json()
.await
.map_err(|e| LlmError::Json(e.to_string()))?;
let expires_at = token_response.expires_in.map(|secs| {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs() + secs)
.unwrap_or(0)
});
Ok(AuthMethod::OAuth {
access_token: token_response.access_token,
refresh_token: token_response.refresh_token,
expires_at,
})
}
}
/// Helper to perform the full device auth flow with polling
pub async fn perform_device_auth<F>(
auth: &OpenAIAuth,
on_code: F,
) -> Result<AuthMethod, LlmError>
where
F: FnOnce(&DeviceCodeResponse),
{
// Start the device flow
let device_code = auth.start_device_auth().await?;
// Let caller display the code to user
on_code(&device_code);
// Poll for completion
let poll_interval = std::time::Duration::from_secs(device_code.interval);
let deadline =
std::time::Instant::now() + std::time::Duration::from_secs(device_code.expires_in);
loop {
if std::time::Instant::now() > deadline {
return Err(LlmError::Auth("Device code expired".to_string()));
}
tokio::time::sleep(poll_interval).await;
match auth.poll_device_auth(&device_code.device_code).await? {
DeviceAuthResult::Success {
access_token,
refresh_token,
expires_in,
} => {
let expires_at = expires_in.map(|secs| {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs() + secs)
.unwrap_or(0)
});
return Ok(AuthMethod::OAuth {
access_token,
refresh_token,
expires_at,
});
}
DeviceAuthResult::Pending => continue,
DeviceAuthResult::Denied => {
return Err(LlmError::Auth("Authorization denied by user".to_string()));
}
DeviceAuthResult::Expired => {
return Err(LlmError::Auth("Device code expired".to_string()));
}
}
}
}

View File

@@ -0,0 +1,561 @@
//! OpenAI GPT API Client
//!
//! Implements the Chat Completions API with streaming support.
use crate::types::*;
use async_trait::async_trait;
use futures::StreamExt;
use llm_core::{
AccountInfo, AuthMethod, ChatMessage, ChatOptions, ChatResponse, ChunkStream, FunctionCall,
LlmError, LlmProvider, ModelInfo, ProviderInfo, ProviderStatus, StreamChunk, Tool, ToolCall,
ToolCallDelta, Usage, UsageStats,
};
use reqwest::Client;
use tokio::io::AsyncBufReadExt;
use tokio_stream::wrappers::LinesStream;
use tokio_util::io::StreamReader;
const API_BASE_URL: &str = "https://api.openai.com/v1";
const CHAT_ENDPOINT: &str = "/chat/completions";
const MODELS_ENDPOINT: &str = "/models";
/// OpenAI GPT API client
pub struct OpenAIClient {
http: Client,
auth: AuthMethod,
model: String,
}
impl OpenAIClient {
/// Create a new client with API key authentication
pub fn new(api_key: impl Into<String>) -> Self {
Self {
http: Client::new(),
auth: AuthMethod::api_key(api_key),
model: "gpt-4o".to_string(),
}
}
/// Create a new client with OAuth token
pub fn with_oauth(access_token: impl Into<String>) -> Self {
Self {
http: Client::new(),
auth: AuthMethod::oauth(access_token),
model: "gpt-4o".to_string(),
}
}
/// Create a new client with full AuthMethod
pub fn with_auth(auth: AuthMethod) -> Self {
Self {
http: Client::new(),
auth,
model: "gpt-4o".to_string(),
}
}
/// Set the model to use
pub fn with_model(mut self, model: impl Into<String>) -> Self {
self.model = model.into();
self
}
/// Get current auth method (for token refresh)
pub fn auth(&self) -> &AuthMethod {
&self.auth
}
/// Update the auth method (after refresh)
pub fn set_auth(&mut self, auth: AuthMethod) {
self.auth = auth;
}
/// Convert messages to OpenAI format
fn prepare_messages(messages: &[ChatMessage]) -> Vec<OpenAIMessage> {
messages.iter().map(OpenAIMessage::from).collect()
}
/// Convert tools to OpenAI format
fn prepare_tools(tools: Option<&[Tool]>) -> Option<Vec<OpenAITool>> {
tools.map(|t| t.iter().map(OpenAITool::from).collect())
}
}
#[async_trait]
impl LlmProvider for OpenAIClient {
fn name(&self) -> &str {
"openai"
}
fn model(&self) -> &str {
&self.model
}
async fn chat_stream(
&self,
messages: &[ChatMessage],
options: &ChatOptions,
tools: Option<&[Tool]>,
) -> Result<ChunkStream, LlmError> {
let url = format!("{}{}", API_BASE_URL, CHAT_ENDPOINT);
let model = if options.model.is_empty() {
&self.model
} else {
&options.model
};
let openai_messages = Self::prepare_messages(messages);
let openai_tools = Self::prepare_tools(tools);
let request = ChatCompletionRequest {
model,
messages: openai_messages,
temperature: options.temperature,
max_tokens: options.max_tokens,
top_p: options.top_p,
stop: options.stop.as_deref(),
tools: openai_tools,
tool_choice: None,
stream: true,
};
let bearer = self
.auth
.bearer_token()
.ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;
let response = self
.http
.post(&url)
.header("Authorization", format!("Bearer {}", bearer))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await
.map_err(|e| LlmError::Http(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
return Err(LlmError::RateLimit {
retry_after_secs: None,
});
}
// Try to parse as error response
if let Ok(err_resp) = serde_json::from_str::<ErrorResponse>(&text) {
return Err(LlmError::Api {
message: err_resp.error.message,
code: err_resp.error.code,
});
}
return Err(LlmError::Api {
message: text,
code: Some(status.to_string()),
});
}
// Parse SSE stream
let byte_stream = response
.bytes_stream()
.map(|result| result.map_err(std::io::Error::other));
let reader = StreamReader::new(byte_stream);
let buf_reader = tokio::io::BufReader::new(reader);
let lines_stream = LinesStream::new(buf_reader.lines());
let chunk_stream = lines_stream.filter_map(|line_result| async move {
match line_result {
Ok(line) => parse_sse_line(&line),
Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
}
});
Ok(Box::pin(chunk_stream))
}
async fn chat(
&self,
messages: &[ChatMessage],
options: &ChatOptions,
tools: Option<&[Tool]>,
) -> Result<ChatResponse, LlmError> {
let url = format!("{}{}", API_BASE_URL, CHAT_ENDPOINT);
let model = if options.model.is_empty() {
&self.model
} else {
&options.model
};
let openai_messages = Self::prepare_messages(messages);
let openai_tools = Self::prepare_tools(tools);
let request = ChatCompletionRequest {
model,
messages: openai_messages,
temperature: options.temperature,
max_tokens: options.max_tokens,
top_p: options.top_p,
stop: options.stop.as_deref(),
tools: openai_tools,
tool_choice: None,
stream: false,
};
let bearer = self
.auth
.bearer_token()
.ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;
let response = self
.http
.post(&url)
.header("Authorization", format!("Bearer {}", bearer))
.json(&request)
.send()
.await
.map_err(|e| LlmError::Http(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let text = response
.text()
.await
.unwrap_or_else(|_| "Unknown error".to_string());
if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
return Err(LlmError::RateLimit {
retry_after_secs: None,
});
}
if let Ok(err_resp) = serde_json::from_str::<ErrorResponse>(&text) {
return Err(LlmError::Api {
message: err_resp.error.message,
code: err_resp.error.code,
});
}
return Err(LlmError::Api {
message: text,
code: Some(status.to_string()),
});
}
let api_response: ChatCompletionResponse = response
.json()
.await
.map_err(|e| LlmError::Json(e.to_string()))?;
// Extract the first choice
let choice = api_response
.choices
.first()
.ok_or_else(|| LlmError::Api {
message: "No choices in response".to_string(),
code: None,
})?;
let content = choice.message.content.clone();
let tool_calls = choice.message.tool_calls.as_ref().map(|calls| {
calls
.iter()
.map(|call| {
let arguments: serde_json::Value =
serde_json::from_str(&call.function.arguments).unwrap_or_default();
ToolCall {
id: call.id.clone(),
call_type: "function".to_string(),
function: FunctionCall {
name: call.function.name.clone(),
arguments,
},
}
})
.collect()
});
let usage = api_response.usage.map(|u| Usage {
prompt_tokens: u.prompt_tokens,
completion_tokens: u.completion_tokens,
total_tokens: u.total_tokens,
});
Ok(ChatResponse {
content,
tool_calls,
usage,
})
}
}
/// Parse a single SSE line into a StreamChunk
fn parse_sse_line(line: &str) -> Option<Result<StreamChunk, LlmError>> {
let line = line.trim();
// Skip empty lines and comments
if line.is_empty() || line.starts_with(':') {
return None;
}
// SSE format: "data: <json>"
if let Some(data) = line.strip_prefix("data: ") {
// OpenAI sends [DONE] to signal end
if data == "[DONE]" {
return Some(Ok(StreamChunk {
content: None,
tool_calls: None,
done: true,
usage: None,
}));
}
// Parse the JSON chunk
match serde_json::from_str::<ChatCompletionChunk>(data) {
Ok(chunk) => Some(convert_chunk_to_stream_chunk(chunk)),
Err(e) => {
tracing::warn!("Failed to parse SSE chunk: {}", e);
None
}
}
} else {
None
}
}
/// Convert OpenAI chunk to our common format
fn convert_chunk_to_stream_chunk(chunk: ChatCompletionChunk) -> Result<StreamChunk, LlmError> {
let choice = chunk.choices.first();
if let Some(choice) = choice {
let content = choice.delta.content.clone();
let tool_calls = choice.delta.tool_calls.as_ref().map(|deltas| {
deltas
.iter()
.map(|delta| ToolCallDelta {
index: delta.index,
id: delta.id.clone(),
function_name: delta.function.as_ref().and_then(|f| f.name.clone()),
arguments_delta: delta.function.as_ref().and_then(|f| f.arguments.clone()),
})
.collect()
});
let done = choice.finish_reason.is_some();
Ok(StreamChunk {
content,
tool_calls,
done,
usage: None,
})
} else {
// No choices, treat as done
Ok(StreamChunk {
content: None,
tool_calls: None,
done: true,
usage: None,
})
}
}
// ============================================================================
// ProviderInfo Implementation
// ============================================================================
/// Known GPT models with their specifications
fn get_gpt_models() -> Vec<ModelInfo> {
vec![
ModelInfo {
id: "gpt-4o".to_string(),
display_name: Some("GPT-4o".to_string()),
description: Some("Most advanced multimodal model with vision".to_string()),
context_window: Some(128_000),
max_output_tokens: Some(16_384),
supports_tools: true,
supports_vision: true,
input_price_per_mtok: Some(2.50),
output_price_per_mtok: Some(10.0),
},
ModelInfo {
id: "gpt-4o-mini".to_string(),
display_name: Some("GPT-4o mini".to_string()),
description: Some("Affordable and fast model for simple tasks".to_string()),
context_window: Some(128_000),
max_output_tokens: Some(16_384),
supports_tools: true,
supports_vision: true,
input_price_per_mtok: Some(0.15),
output_price_per_mtok: Some(0.60),
},
ModelInfo {
id: "gpt-4-turbo".to_string(),
display_name: Some("GPT-4 Turbo".to_string()),
description: Some("Previous generation high-performance model".to_string()),
context_window: Some(128_000),
max_output_tokens: Some(4_096),
supports_tools: true,
supports_vision: true,
input_price_per_mtok: Some(10.0),
output_price_per_mtok: Some(30.0),
},
ModelInfo {
id: "gpt-3.5-turbo".to_string(),
display_name: Some("GPT-3.5 Turbo".to_string()),
description: Some("Fast and affordable for simple tasks".to_string()),
context_window: Some(16_385),
max_output_tokens: Some(4_096),
supports_tools: true,
supports_vision: false,
input_price_per_mtok: Some(0.50),
output_price_per_mtok: Some(1.50),
},
ModelInfo {
id: "o1".to_string(),
display_name: Some("OpenAI o1".to_string()),
description: Some("Reasoning model optimized for complex problems".to_string()),
context_window: Some(200_000),
max_output_tokens: Some(100_000),
supports_tools: false,
supports_vision: true,
input_price_per_mtok: Some(15.0),
output_price_per_mtok: Some(60.0),
},
ModelInfo {
id: "o1-mini".to_string(),
display_name: Some("OpenAI o1-mini".to_string()),
description: Some("Faster reasoning model for STEM".to_string()),
context_window: Some(128_000),
max_output_tokens: Some(65_536),
supports_tools: false,
supports_vision: true,
input_price_per_mtok: Some(3.0),
output_price_per_mtok: Some(12.0),
},
]
}
#[async_trait]
impl ProviderInfo for OpenAIClient {
async fn status(&self) -> Result<ProviderStatus, LlmError> {
let authenticated = self.auth.bearer_token().is_some();
// Try to reach the API by listing models
let reachable = if authenticated {
let url = format!("{}{}", API_BASE_URL, MODELS_ENDPOINT);
let bearer = self.auth.bearer_token().unwrap();
match self
.http
.get(&url)
.header("Authorization", format!("Bearer {}", bearer))
.send()
.await
{
Ok(resp) => resp.status().is_success(),
Err(_) => false,
}
} else {
false
};
let message = if !authenticated {
Some("Not authenticated - set OPENAI_API_KEY or run 'owlen login openai'".to_string())
} else if !reachable {
Some("Cannot reach OpenAI API".to_string())
} else {
Some("Connected".to_string())
};
Ok(ProviderStatus {
provider: "openai".to_string(),
authenticated,
account: None, // OpenAI doesn't expose account info via API
model: self.model.clone(),
endpoint: API_BASE_URL.to_string(),
reachable,
message,
})
}
async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError> {
// OpenAI doesn't have a public account info endpoint
Ok(None)
}
async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError> {
// OpenAI doesn't expose usage stats via the standard API
Ok(None)
}
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
// We can optionally fetch from API, but return known models for now
Ok(get_gpt_models())
}
async fn model_info(&self, model_id: &str) -> Result<Option<ModelInfo>, LlmError> {
let models = get_gpt_models();
Ok(models.into_iter().find(|m| m.id == model_id))
}
}
#[cfg(test)]
mod tests {
use super::*;
use llm_core::ToolParameters;
use serde_json::json;
#[test]
fn test_message_conversion() {
let messages = vec![
ChatMessage::system("You are helpful"),
ChatMessage::user("Hello"),
ChatMessage::assistant("Hi there!"),
];
let openai_msgs = OpenAIClient::prepare_messages(&messages);
assert_eq!(openai_msgs.len(), 3);
assert_eq!(openai_msgs[0].role, "system");
assert_eq!(openai_msgs[1].role, "user");
assert_eq!(openai_msgs[2].role, "assistant");
}
#[test]
fn test_tool_conversion() {
let tools = vec![Tool::function(
"read_file",
"Read a file's contents",
ToolParameters::object(
json!({
"path": {
"type": "string",
"description": "File path"
}
}),
vec!["path".to_string()],
),
)];
let openai_tools = OpenAIClient::prepare_tools(Some(&tools)).unwrap();
assert_eq!(openai_tools.len(), 1);
assert_eq!(openai_tools[0].function.name, "read_file");
assert_eq!(
openai_tools[0].function.description,
"Read a file's contents"
);
}
}

View File

@@ -0,0 +1,12 @@
//! OpenAI GPT API Client
//!
//! Implements the LlmProvider trait for OpenAI's GPT models.
//! Supports both API key authentication and OAuth device flow.
mod auth;
mod client;
mod types;
pub use auth::*;
pub use client::*;
pub use types::*;

View File

@@ -0,0 +1,285 @@
//! OpenAI API request/response types
use serde::{Deserialize, Serialize};
use serde_json::Value;
// ============================================================================
// Request Types
// ============================================================================
#[derive(Debug, Serialize)]
pub struct ChatCompletionRequest<'a> {
pub model: &'a str,
pub messages: Vec<OpenAIMessage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<&'a [String]>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<OpenAITool>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<&'a str>,
pub stream: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIMessage {
pub role: String, // "system", "user", "assistant", "tool"
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<OpenAIToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIToolCall {
pub id: String,
#[serde(rename = "type")]
pub call_type: String,
pub function: OpenAIFunctionCall,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIFunctionCall {
pub name: String,
pub arguments: String, // JSON string
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAITool {
#[serde(rename = "type")]
pub tool_type: String,
pub function: OpenAIFunction,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIFunction {
pub name: String,
pub description: String,
pub parameters: FunctionParameters,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionParameters {
#[serde(rename = "type")]
pub param_type: String,
pub properties: Value,
pub required: Vec<String>,
}
// ============================================================================
// Response Types
// ============================================================================
#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<Choice>,
pub usage: Option<UsageInfo>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct Choice {
pub index: u32,
pub message: OpenAIMessage,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct UsageInfo {
pub prompt_tokens: u32,
pub completion_tokens: u32,
pub total_tokens: u32,
}
// ============================================================================
// Streaming Response Types
// ============================================================================
#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionChunk {
pub id: String,
pub object: String,
pub created: u64,
pub model: String,
pub choices: Vec<ChunkChoice>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ChunkChoice {
pub index: u32,
pub delta: Delta,
pub finish_reason: Option<String>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct Delta {
#[serde(skip_serializing_if = "Option::is_none")]
pub role: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<DeltaToolCall>>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct DeltaToolCall {
pub index: usize,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none", rename = "type")]
pub call_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub function: Option<DeltaFunction>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct DeltaFunction {
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub arguments: Option<String>,
}
// ============================================================================
// Error Response Types
// ============================================================================
#[derive(Debug, Clone, Deserialize)]
pub struct ErrorResponse {
pub error: ApiError,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ApiError {
pub message: String,
#[serde(rename = "type")]
pub error_type: String,
pub code: Option<String>,
}
// ============================================================================
// Models List Response
// ============================================================================
#[derive(Debug, Clone, Deserialize)]
pub struct ModelsResponse {
pub object: String,
pub data: Vec<ModelData>,
}
#[derive(Debug, Clone, Deserialize)]
pub struct ModelData {
pub id: String,
pub object: String,
pub created: u64,
pub owned_by: String,
}
// ============================================================================
// Conversions
// ============================================================================
impl From<&llm_core::Tool> for OpenAITool {
fn from(tool: &llm_core::Tool) -> Self {
Self {
tool_type: "function".to_string(),
function: OpenAIFunction {
name: tool.function.name.clone(),
description: tool.function.description.clone(),
parameters: FunctionParameters {
param_type: tool.function.parameters.param_type.clone(),
properties: tool.function.parameters.properties.clone(),
required: tool.function.parameters.required.clone(),
},
},
}
}
}
impl From<&llm_core::ChatMessage> for OpenAIMessage {
fn from(msg: &llm_core::ChatMessage) -> Self {
use llm_core::Role;
let role = match msg.role {
Role::System => "system",
Role::User => "user",
Role::Assistant => "assistant",
Role::Tool => "tool",
};
// Handle tool result messages
if msg.role == Role::Tool {
return Self {
role: "tool".to_string(),
content: msg.content.clone(),
tool_calls: None,
tool_call_id: msg.tool_call_id.clone(),
name: msg.name.clone(),
};
}
// Handle assistant messages with tool calls
if msg.role == Role::Assistant && msg.tool_calls.is_some() {
let tool_calls = msg.tool_calls.as_ref().map(|calls| {
calls
.iter()
.map(|call| OpenAIToolCall {
id: call.id.clone(),
call_type: "function".to_string(),
function: OpenAIFunctionCall {
name: call.function.name.clone(),
arguments: serde_json::to_string(&call.function.arguments)
.unwrap_or_else(|_| "{}".to_string()),
},
})
.collect()
});
return Self {
role: "assistant".to_string(),
content: msg.content.clone(),
tool_calls,
tool_call_id: None,
name: None,
};
}
// Simple text message
Self {
role: role.to_string(),
content: msg.content.clone(),
tool_calls: None,
tool_call_id: None,
name: None,
}
}
}