feat(v2): complete multi-LLM providers, TUI redesign, and advanced agent features
Multi-LLM Provider Support:
- Add llm-core crate with LlmProvider trait abstraction
- Implement Anthropic Claude API client with streaming
- Implement OpenAI API client with streaming
- Add token counting with SimpleTokenCounter and ClaudeTokenCounter
- Add retry logic with exponential backoff and jitter

Borderless TUI Redesign:
- Rewrite theme system with terminal capability detection (Full/Unicode256/Basic)
- Add provider tabs component with keybind switching [1]/[2]/[3]
- Implement vim-modal input (Normal/Insert/Visual/Command modes)
- Redesign chat panel with timestamps and streaming indicators
- Add multi-provider status bar with cost tracking
- Add Nerd Font icons with graceful ASCII fallbacks
- Add syntax highlighting (syntect) and markdown rendering (pulldown-cmark)

Advanced Agent Features:
- Add system prompt builder with configurable components
- Enhance subagent orchestration with parallel execution
- Add git integration module for safe command detection
- Add streaming tool results via channels
- Expand tool set: AskUserQuestion, TodoWrite, LS, MultiEdit, BashOutput, KillShell
- Add WebSearch with provider abstraction

Plugin System Enhancement:
- Add full agent definition parsing from YAML frontmatter
- Add skill system with progressive disclosure
- Wire plugin hooks into HookManager

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
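For orientation, the provider abstraction this commit introduces is consumed through the `LlmProvider` trait from `llm-core`. Below is a minimal, illustrative sketch of streaming a reply through the new Anthropic client; the types and methods are the ones added in this diff, but the `main` wiring, the `ANTHROPIC_API_KEY` environment variable, and the tokio runtime setup are assumptions for the example only.

```rust
// Illustrative sketch only — not part of this commit.
use futures::StreamExt;
use llm_anthropic::AnthropicClient;
use llm_core::{ChatMessage, ChatOptions, LlmProvider};

#[tokio::main] // assumes tokio with the "macros" and "rt-multi-thread" features
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // API key pulled from the environment purely for demonstration.
    let client = AnthropicClient::new(std::env::var("ANTHROPIC_API_KEY")?)
        .with_model("claude-sonnet-4-20250514");

    let messages = vec![
        ChatMessage::system("You are a concise assistant."),
        ChatMessage::user("Summarize what the LlmProvider trait is for."),
    ];
    let options = ChatOptions::new("claude-sonnet-4-20250514").with_max_tokens(256);

    // chat_stream yields StreamChunk items until one arrives with done == true.
    let mut stream = client.chat_stream(&messages, &options, None).await?;
    while let Some(chunk) = stream.next().await {
        let chunk = chunk?;
        if let Some(text) = chunk.content {
            print!("{text}");
        }
        if chunk.done {
            break;
        }
    }
    Ok(())
}
```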
crates/llm/anthropic/Cargo.toml (new file, +18)
@@ -0,0 +1,18 @@
[package]
name = "llm-anthropic"
version = "0.1.0"
edition.workspace = true
license.workspace = true
description = "Anthropic Claude API client for Owlen"

[dependencies]
llm-core = { path = "../core" }
async-trait = "0.1"
futures = "0.3"
reqwest = { version = "0.12", features = ["json", "stream"] }
reqwest-eventsource = "0.6"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1", features = ["sync", "time"] }
tracing = "0.1"
uuid = { version = "1.0", features = ["v4"] }
crates/llm/anthropic/src/auth.rs (new file, +285)
@@ -0,0 +1,285 @@
|
||||
//! Anthropic OAuth Authentication
|
||||
//!
|
||||
//! Implements device code flow for authenticating with Anthropic without API keys.
|
||||
|
||||
use llm_core::{AuthMethod, DeviceAuthResult, DeviceCodeResponse, LlmError, OAuthProvider};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// OAuth client for Anthropic device flow
|
||||
pub struct AnthropicAuth {
|
||||
http: Client,
|
||||
client_id: String,
|
||||
}
|
||||
|
||||
// Anthropic OAuth endpoints (these would be the real endpoints)
|
||||
const AUTH_BASE_URL: &str = "https://console.anthropic.com";
|
||||
const DEVICE_CODE_ENDPOINT: &str = "/oauth/device/code";
|
||||
const TOKEN_ENDPOINT: &str = "/oauth/token";
|
||||
|
||||
// Default client ID for Owlen CLI
|
||||
const DEFAULT_CLIENT_ID: &str = "owlen-cli";
|
||||
|
||||
impl AnthropicAuth {
|
||||
/// Create a new OAuth client with the default CLI client ID
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
client_id: DEFAULT_CLIENT_ID.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with a custom client ID
|
||||
pub fn with_client_id(client_id: impl Into<String>) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
client_id: client_id.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for AnthropicAuth {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct DeviceCodeRequest<'a> {
|
||||
client_id: &'a str,
|
||||
scope: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeviceCodeApiResponse {
|
||||
device_code: String,
|
||||
user_code: String,
|
||||
verification_uri: String,
|
||||
verification_uri_complete: Option<String>,
|
||||
expires_in: u64,
|
||||
interval: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct TokenRequest<'a> {
|
||||
client_id: &'a str,
|
||||
device_code: &'a str,
|
||||
grant_type: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenApiResponse {
|
||||
access_token: String,
|
||||
#[allow(dead_code)]
|
||||
token_type: String,
|
||||
expires_in: Option<u64>,
|
||||
refresh_token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenErrorResponse {
|
||||
error: String,
|
||||
error_description: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OAuthProvider for AnthropicAuth {
|
||||
async fn start_device_auth(&self) -> Result<DeviceCodeResponse, LlmError> {
|
||||
let url = format!("{}{}", AUTH_BASE_URL, DEVICE_CODE_ENDPOINT);
|
||||
|
||||
let request = DeviceCodeRequest {
|
||||
client_id: &self.client_id,
|
||||
scope: "api:read api:write", // Request API access
|
||||
};
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(LlmError::Auth(format!(
|
||||
"Device code request failed ({}): {}",
|
||||
status, text
|
||||
)));
|
||||
}
|
||||
|
||||
let api_response: DeviceCodeApiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
Ok(DeviceCodeResponse {
|
||||
device_code: api_response.device_code,
|
||||
user_code: api_response.user_code,
|
||||
verification_uri: api_response.verification_uri,
|
||||
verification_uri_complete: api_response.verification_uri_complete,
|
||||
expires_in: api_response.expires_in,
|
||||
interval: api_response.interval,
|
||||
})
|
||||
}
|
||||
|
||||
async fn poll_device_auth(&self, device_code: &str) -> Result<DeviceAuthResult, LlmError> {
|
||||
let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);
|
||||
|
||||
let request = TokenRequest {
|
||||
client_id: &self.client_id,
|
||||
device_code,
|
||||
grant_type: "urn:ietf:params:oauth:grant-type:device_code",
|
||||
};
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let token_response: TokenApiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
return Ok(DeviceAuthResult::Success {
|
||||
access_token: token_response.access_token,
|
||||
refresh_token: token_response.refresh_token,
|
||||
expires_in: token_response.expires_in,
|
||||
});
|
||||
}
|
||||
|
||||
// Parse error response
|
||||
let error_response: TokenErrorResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
match error_response.error.as_str() {
|
||||
"authorization_pending" => Ok(DeviceAuthResult::Pending),
|
||||
"slow_down" => Ok(DeviceAuthResult::Pending), // Treat as pending, caller should slow down
|
||||
"access_denied" => Ok(DeviceAuthResult::Denied),
|
||||
"expired_token" => Ok(DeviceAuthResult::Expired),
|
||||
_ => Err(LlmError::Auth(format!(
|
||||
"Token request failed: {} - {}",
|
||||
error_response.error,
|
||||
error_response.error_description.unwrap_or_default()
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
async fn refresh_token(&self, refresh_token: &str) -> Result<AuthMethod, LlmError> {
|
||||
let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct RefreshRequest<'a> {
|
||||
client_id: &'a str,
|
||||
refresh_token: &'a str,
|
||||
grant_type: &'a str,
|
||||
}
|
||||
|
||||
let request = RefreshRequest {
|
||||
client_id: &self.client_id,
|
||||
refresh_token,
|
||||
grant_type: "refresh_token",
|
||||
};
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(LlmError::Auth(format!("Token refresh failed: {}", text)));
|
||||
}
|
||||
|
||||
let token_response: TokenApiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
let expires_at = token_response.expires_in.map(|secs| {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs() + secs)
|
||||
.unwrap_or(0)
|
||||
});
|
||||
|
||||
Ok(AuthMethod::OAuth {
|
||||
access_token: token_response.access_token,
|
||||
refresh_token: token_response.refresh_token,
|
||||
expires_at,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to perform the full device auth flow with polling
|
||||
pub async fn perform_device_auth<F>(
|
||||
auth: &AnthropicAuth,
|
||||
on_code: F,
|
||||
) -> Result<AuthMethod, LlmError>
|
||||
where
|
||||
F: FnOnce(&DeviceCodeResponse),
|
||||
{
|
||||
// Start the device flow
|
||||
let device_code = auth.start_device_auth().await?;
|
||||
|
||||
// Let caller display the code to user
|
||||
on_code(&device_code);
|
||||
|
||||
// Poll for completion
|
||||
let poll_interval = std::time::Duration::from_secs(device_code.interval);
|
||||
let deadline =
|
||||
std::time::Instant::now() + std::time::Duration::from_secs(device_code.expires_in);
|
||||
|
||||
loop {
|
||||
if std::time::Instant::now() > deadline {
|
||||
return Err(LlmError::Auth("Device code expired".to_string()));
|
||||
}
|
||||
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
|
||||
match auth.poll_device_auth(&device_code.device_code).await? {
|
||||
DeviceAuthResult::Success {
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_in,
|
||||
} => {
|
||||
let expires_at = expires_in.map(|secs| {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs() + secs)
|
||||
.unwrap_or(0)
|
||||
});
|
||||
|
||||
return Ok(AuthMethod::OAuth {
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_at,
|
||||
});
|
||||
}
|
||||
DeviceAuthResult::Pending => continue,
|
||||
DeviceAuthResult::Denied => {
|
||||
return Err(LlmError::Auth("Authorization denied by user".to_string()));
|
||||
}
|
||||
DeviceAuthResult::Expired => {
|
||||
return Err(LlmError::Auth("Device code expired".to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
crates/llm/anthropic/src/client.rs (new file, +577)
@@ -0,0 +1,577 @@
|
||||
//! Anthropic Claude API Client
|
||||
//!
|
||||
//! Implements the Messages API with streaming support.
|
||||
|
||||
use crate::types::*;
|
||||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use llm_core::{
|
||||
AccountInfo, AuthMethod, ChatMessage, ChatOptions, ChatResponse, ChunkStream, FunctionCall,
|
||||
LlmError, LlmProvider, ModelInfo, ProviderInfo, ProviderStatus, Role, StreamChunk, Tool,
|
||||
ToolCall, ToolCallDelta, Usage, UsageStats,
|
||||
};
|
||||
use reqwest::Client;
|
||||
use reqwest_eventsource::{Event, EventSource};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
const API_BASE_URL: &str = "https://api.anthropic.com";
|
||||
const MESSAGES_ENDPOINT: &str = "/v1/messages";
|
||||
const API_VERSION: &str = "2023-06-01";
|
||||
const DEFAULT_MAX_TOKENS: u32 = 8192;
|
||||
|
||||
/// Anthropic Claude API client
|
||||
pub struct AnthropicClient {
|
||||
http: Client,
|
||||
auth: AuthMethod,
|
||||
model: String,
|
||||
}
|
||||
|
||||
impl AnthropicClient {
|
||||
/// Create a new client with API key authentication
|
||||
pub fn new(api_key: impl Into<String>) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
auth: AuthMethod::api_key(api_key),
|
||||
model: "claude-sonnet-4-20250514".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new client with OAuth token
|
||||
pub fn with_oauth(access_token: impl Into<String>) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
auth: AuthMethod::oauth(access_token),
|
||||
model: "claude-sonnet-4-20250514".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new client with full AuthMethod
|
||||
pub fn with_auth(auth: AuthMethod) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
auth,
|
||||
model: "claude-sonnet-4-20250514".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the model to use
|
||||
pub fn with_model(mut self, model: impl Into<String>) -> Self {
|
||||
self.model = model.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Get current auth method (for token refresh)
|
||||
pub fn auth(&self) -> &AuthMethod {
|
||||
&self.auth
|
||||
}
|
||||
|
||||
/// Update the auth method (after refresh)
|
||||
pub fn set_auth(&mut self, auth: AuthMethod) {
|
||||
self.auth = auth;
|
||||
}
|
||||
|
||||
/// Convert messages to Anthropic format, extracting system message
|
||||
fn prepare_messages(messages: &[ChatMessage]) -> (Option<String>, Vec<AnthropicMessage>) {
|
||||
let mut system_content = None;
|
||||
let mut anthropic_messages = Vec::new();
|
||||
|
||||
for msg in messages {
|
||||
if msg.role == Role::System {
|
||||
// Collect system messages
|
||||
if let Some(content) = &msg.content {
|
||||
if let Some(existing) = &mut system_content {
|
||||
*existing = format!("{}\n\n{}", existing, content);
|
||||
} else {
|
||||
system_content = Some(content.clone());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
anthropic_messages.push(AnthropicMessage::from(msg));
|
||||
}
|
||||
}
|
||||
|
||||
(system_content, anthropic_messages)
|
||||
}
|
||||
|
||||
/// Convert tools to Anthropic format
|
||||
fn prepare_tools(tools: Option<&[Tool]>) -> Option<Vec<AnthropicTool>> {
|
||||
tools.map(|t| t.iter().map(AnthropicTool::from).collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for AnthropicClient {
|
||||
fn name(&self) -> &str {
|
||||
"anthropic"
|
||||
}
|
||||
|
||||
fn model(&self) -> &str {
|
||||
&self.model
|
||||
}
|
||||
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
options: &ChatOptions,
|
||||
tools: Option<&[Tool]>,
|
||||
) -> Result<ChunkStream, LlmError> {
|
||||
let url = format!("{}{}", API_BASE_URL, MESSAGES_ENDPOINT);
|
||||
|
||||
let model = if options.model.is_empty() {
|
||||
&self.model
|
||||
} else {
|
||||
&options.model
|
||||
};
|
||||
|
||||
let (system, anthropic_messages) = Self::prepare_messages(messages);
|
||||
let anthropic_tools = Self::prepare_tools(tools);
|
||||
|
||||
let request = MessagesRequest {
|
||||
model,
|
||||
messages: anthropic_messages,
|
||||
max_tokens: options.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
|
||||
system: system.as_deref(),
|
||||
temperature: options.temperature,
|
||||
top_p: options.top_p,
|
||||
stop_sequences: options.stop.as_deref(),
|
||||
tools: anthropic_tools,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let bearer = self
|
||||
.auth
|
||||
.bearer_token()
|
||||
.ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;
|
||||
|
||||
// Build the SSE request
|
||||
let req = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("x-api-key", bearer)
|
||||
.header("anthropic-version", API_VERSION)
|
||||
.header("content-type", "application/json")
|
||||
.json(&request);
|
||||
|
||||
let es = EventSource::new(req).map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
// State for accumulating tool calls across deltas
|
||||
let tool_state: Arc<Mutex<Vec<PartialToolCall>>> = Arc::new(Mutex::new(Vec::new()));
|
||||
|
||||
let stream = es.filter_map(move |event| {
|
||||
let tool_state = Arc::clone(&tool_state);
|
||||
async move {
|
||||
match event {
|
||||
Ok(Event::Open) => None,
|
||||
Ok(Event::Message(msg)) => {
|
||||
// Parse the SSE data as JSON
|
||||
let event: StreamEvent = match serde_json::from_str(&msg.data) {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse SSE event: {}", e);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
convert_stream_event(event, &tool_state).await
|
||||
}
|
||||
Err(reqwest_eventsource::Error::StreamEnded) => None,
|
||||
Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Box::pin(stream))
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
options: &ChatOptions,
|
||||
tools: Option<&[Tool]>,
|
||||
) -> Result<ChatResponse, LlmError> {
|
||||
let url = format!("{}{}", API_BASE_URL, MESSAGES_ENDPOINT);
|
||||
|
||||
let model = if options.model.is_empty() {
|
||||
&self.model
|
||||
} else {
|
||||
&options.model
|
||||
};
|
||||
|
||||
let (system, anthropic_messages) = Self::prepare_messages(messages);
|
||||
let anthropic_tools = Self::prepare_tools(tools);
|
||||
|
||||
let request = MessagesRequest {
|
||||
model,
|
||||
messages: anthropic_messages,
|
||||
max_tokens: options.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
|
||||
system: system.as_deref(),
|
||||
temperature: options.temperature,
|
||||
top_p: options.top_p,
|
||||
stop_sequences: options.stop.as_deref(),
|
||||
tools: anthropic_tools,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let bearer = self
|
||||
.auth
|
||||
.bearer_token()
|
||||
.ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("x-api-key", bearer)
|
||||
.header("anthropic-version", API_VERSION)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
// Check for rate limiting
|
||||
if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
|
||||
return Err(LlmError::RateLimit {
|
||||
retry_after_secs: None,
|
||||
});
|
||||
}
|
||||
|
||||
return Err(LlmError::Api {
|
||||
message: text,
|
||||
code: Some(status.to_string()),
|
||||
});
|
||||
}
|
||||
|
||||
let api_response: MessagesResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
// Convert response to common format
|
||||
let mut content = String::new();
|
||||
let mut tool_calls = Vec::new();
|
||||
|
||||
for block in api_response.content {
|
||||
match block {
|
||||
ResponseContentBlock::Text { text } => {
|
||||
content.push_str(&text);
|
||||
}
|
||||
ResponseContentBlock::ToolUse { id, name, input } => {
|
||||
tool_calls.push(ToolCall {
|
||||
id,
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name,
|
||||
arguments: input,
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let usage = api_response.usage.map(|u| Usage {
|
||||
prompt_tokens: u.input_tokens,
|
||||
completion_tokens: u.output_tokens,
|
||||
total_tokens: u.input_tokens + u.output_tokens,
|
||||
});
|
||||
|
||||
Ok(ChatResponse {
|
||||
content: if content.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(content)
|
||||
},
|
||||
tool_calls: if tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(tool_calls)
|
||||
},
|
||||
usage,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper struct for accumulating streaming tool calls
|
||||
#[derive(Default)]
|
||||
struct PartialToolCall {
|
||||
#[allow(dead_code)]
|
||||
id: String,
|
||||
#[allow(dead_code)]
|
||||
name: String,
|
||||
input_json: String,
|
||||
}
|
||||
|
||||
/// Convert an Anthropic stream event to our common StreamChunk format
|
||||
async fn convert_stream_event(
|
||||
event: StreamEvent,
|
||||
tool_state: &Arc<Mutex<Vec<PartialToolCall>>>,
|
||||
) -> Option<Result<StreamChunk, LlmError>> {
|
||||
match event {
|
||||
StreamEvent::ContentBlockStart {
|
||||
index,
|
||||
content_block,
|
||||
} => {
|
||||
match content_block {
|
||||
ContentBlockStartInfo::Text { text } => {
|
||||
if text.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(Ok(StreamChunk {
|
||||
content: Some(text),
|
||||
tool_calls: None,
|
||||
done: false,
|
||||
usage: None,
|
||||
}))
|
||||
}
|
||||
}
|
||||
ContentBlockStartInfo::ToolUse { id, name } => {
|
||||
// Store the tool call start
|
||||
let mut state = tool_state.lock().await;
|
||||
while state.len() <= index {
|
||||
state.push(PartialToolCall::default());
|
||||
}
|
||||
state[index] = PartialToolCall {
|
||||
id: id.clone(),
|
||||
name: name.clone(),
|
||||
input_json: String::new(),
|
||||
};
|
||||
|
||||
Some(Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index,
|
||||
id: Some(id),
|
||||
function_name: Some(name),
|
||||
arguments_delta: None,
|
||||
}]),
|
||||
done: false,
|
||||
usage: None,
|
||||
}))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
StreamEvent::ContentBlockDelta { index, delta } => match delta {
|
||||
ContentDelta::TextDelta { text } => Some(Ok(StreamChunk {
|
||||
content: Some(text),
|
||||
tool_calls: None,
|
||||
done: false,
|
||||
usage: None,
|
||||
})),
|
||||
ContentDelta::InputJsonDelta { partial_json } => {
|
||||
// Accumulate the JSON
|
||||
let mut state = tool_state.lock().await;
|
||||
if index < state.len() {
|
||||
state[index].input_json.push_str(&partial_json);
|
||||
}
|
||||
|
||||
Some(Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCallDelta {
|
||||
index,
|
||||
id: None,
|
||||
function_name: None,
|
||||
arguments_delta: Some(partial_json),
|
||||
}]),
|
||||
done: false,
|
||||
usage: None,
|
||||
}))
|
||||
}
|
||||
},
|
||||
|
||||
StreamEvent::MessageDelta { usage, .. } => {
|
||||
let u = usage.map(|u| Usage {
|
||||
prompt_tokens: u.input_tokens,
|
||||
completion_tokens: u.output_tokens,
|
||||
total_tokens: u.input_tokens + u.output_tokens,
|
||||
});
|
||||
|
||||
Some(Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
done: false,
|
||||
usage: u,
|
||||
}))
|
||||
}
|
||||
|
||||
StreamEvent::MessageStop => Some(Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
done: true,
|
||||
usage: None,
|
||||
})),
|
||||
|
||||
StreamEvent::Error { error } => Some(Err(LlmError::Api {
|
||||
message: error.message,
|
||||
code: Some(error.error_type),
|
||||
})),
|
||||
|
||||
// Ignore other events
|
||||
StreamEvent::MessageStart { .. }
|
||||
| StreamEvent::ContentBlockStop { .. }
|
||||
| StreamEvent::Ping => None,
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ProviderInfo Implementation
|
||||
// ============================================================================
|
||||
|
||||
/// Known Claude models with their specifications
|
||||
fn get_claude_models() -> Vec<ModelInfo> {
|
||||
vec![
|
||||
ModelInfo {
|
||||
id: "claude-opus-4-20250514".to_string(),
|
||||
display_name: Some("Claude Opus 4".to_string()),
|
||||
description: Some("Most capable model for complex tasks".to_string()),
|
||||
context_window: Some(200_000),
|
||||
max_output_tokens: Some(32_000),
|
||||
supports_tools: true,
|
||||
supports_vision: true,
|
||||
input_price_per_mtok: Some(15.0),
|
||||
output_price_per_mtok: Some(75.0),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "claude-sonnet-4-20250514".to_string(),
|
||||
display_name: Some("Claude Sonnet 4".to_string()),
|
||||
description: Some("Best balance of performance and speed".to_string()),
|
||||
context_window: Some(200_000),
|
||||
max_output_tokens: Some(64_000),
|
||||
supports_tools: true,
|
||||
supports_vision: true,
|
||||
input_price_per_mtok: Some(3.0),
|
||||
output_price_per_mtok: Some(15.0),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "claude-haiku-3-5-20241022".to_string(),
|
||||
display_name: Some("Claude 3.5 Haiku".to_string()),
|
||||
description: Some("Fast and affordable for simple tasks".to_string()),
|
||||
context_window: Some(200_000),
|
||||
max_output_tokens: Some(8_192),
|
||||
supports_tools: true,
|
||||
supports_vision: true,
|
||||
input_price_per_mtok: Some(0.80),
|
||||
output_price_per_mtok: Some(4.0),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ProviderInfo for AnthropicClient {
|
||||
async fn status(&self) -> Result<ProviderStatus, LlmError> {
|
||||
let authenticated = self.auth.bearer_token().is_some();
|
||||
|
||||
// Try to reach the API with a simple request
|
||||
let reachable = if authenticated {
|
||||
// Test with a minimal message to verify auth works
|
||||
let test_messages = vec![ChatMessage::user("Hi")];
|
||||
let test_opts = ChatOptions::new(&self.model).with_max_tokens(1);
|
||||
|
||||
match self.chat(&test_messages, &test_opts, None).await {
|
||||
Ok(_) => true,
|
||||
Err(LlmError::Auth(_)) => false, // Auth failed
|
||||
Err(_) => true, // Other errors mean API is reachable
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
let account = if authenticated && reachable {
|
||||
self.account_info().await.ok().flatten()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let message = if !authenticated {
|
||||
Some("Not authenticated - run 'owlen login anthropic' to authenticate".to_string())
|
||||
} else if !reachable {
|
||||
Some("Cannot reach Anthropic API".to_string())
|
||||
} else {
|
||||
Some("Connected".to_string())
|
||||
};
|
||||
|
||||
Ok(ProviderStatus {
|
||||
provider: "anthropic".to_string(),
|
||||
authenticated,
|
||||
account,
|
||||
model: self.model.clone(),
|
||||
endpoint: API_BASE_URL.to_string(),
|
||||
reachable,
|
||||
message,
|
||||
})
|
||||
}
|
||||
|
||||
async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError> {
|
||||
// Anthropic doesn't have a public account info endpoint
|
||||
// Return None - account info would come from OAuth token claims
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError> {
|
||||
// Anthropic doesn't expose usage stats via API
|
||||
// This would require the admin/billing API with different auth
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
|
||||
// Return known models - Anthropic doesn't have a models list endpoint
|
||||
Ok(get_claude_models())
|
||||
}
|
||||
|
||||
async fn model_info(&self, model_id: &str) -> Result<Option<ModelInfo>, LlmError> {
|
||||
let models = get_claude_models();
|
||||
Ok(models.into_iter().find(|m| m.id == model_id))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use llm_core::ToolParameters;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_message_conversion() {
|
||||
let messages = vec![
|
||||
ChatMessage::system("You are helpful"),
|
||||
ChatMessage::user("Hello"),
|
||||
ChatMessage::assistant("Hi there!"),
|
||||
];
|
||||
|
||||
let (system, anthropic_msgs) = AnthropicClient::prepare_messages(&messages);
|
||||
|
||||
assert_eq!(system, Some("You are helpful".to_string()));
|
||||
assert_eq!(anthropic_msgs.len(), 2);
|
||||
assert_eq!(anthropic_msgs[0].role, "user");
|
||||
assert_eq!(anthropic_msgs[1].role, "assistant");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_conversion() {
|
||||
let tools = vec![Tool::function(
|
||||
"read_file",
|
||||
"Read a file's contents",
|
||||
ToolParameters::object(
|
||||
json!({
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path"
|
||||
}
|
||||
}),
|
||||
vec!["path".to_string()],
|
||||
),
|
||||
)];
|
||||
|
||||
let anthropic_tools = AnthropicClient::prepare_tools(Some(&tools)).unwrap();
|
||||
|
||||
assert_eq!(anthropic_tools.len(), 1);
|
||||
assert_eq!(anthropic_tools[0].name, "read_file");
|
||||
assert_eq!(anthropic_tools[0].description, "Read a file's contents");
|
||||
}
|
||||
}
|
||||
crates/llm/anthropic/src/lib.rs (new file, +12)
@@ -0,0 +1,12 @@
//! Anthropic Claude API Client
//!
//! Implements the LlmProvider trait for Anthropic's Claude models.
//! Supports both API key authentication and OAuth device flow.

mod auth;
mod client;
mod types;

pub use auth::*;
pub use client::*;
pub use types::*;
crates/llm/anthropic/src/types.rs (new file, +276)
@@ -0,0 +1,276 @@
|
||||
//! Anthropic API request/response types
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
// ============================================================================
|
||||
// Request Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct MessagesRequest<'a> {
|
||||
pub model: &'a str,
|
||||
pub messages: Vec<AnthropicMessage>,
|
||||
pub max_tokens: u32,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub system: Option<&'a str>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop_sequences: Option<&'a [String]>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<AnthropicTool>>,
|
||||
|
||||
pub stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AnthropicMessage {
|
||||
pub role: String, // "user" or "assistant"
|
||||
pub content: AnthropicContent,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum AnthropicContent {
|
||||
Text(String),
|
||||
Blocks(Vec<ContentBlock>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ContentBlock {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
|
||||
#[serde(rename = "tool_use")]
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: Value,
|
||||
},
|
||||
|
||||
#[serde(rename = "tool_result")]
|
||||
ToolResult {
|
||||
tool_use_id: String,
|
||||
content: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
is_error: Option<bool>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AnthropicTool {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub input_schema: ToolInputSchema,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolInputSchema {
|
||||
#[serde(rename = "type")]
|
||||
pub schema_type: String,
|
||||
pub properties: Value,
|
||||
pub required: Vec<String>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Response Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct MessagesResponse {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub response_type: String,
|
||||
pub role: String,
|
||||
pub content: Vec<ResponseContentBlock>,
|
||||
pub model: String,
|
||||
pub stop_reason: Option<String>,
|
||||
pub usage: Option<UsageInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ResponseContentBlock {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
|
||||
#[serde(rename = "tool_use")]
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: Value,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct UsageInfo {
|
||||
pub input_tokens: u32,
|
||||
pub output_tokens: u32,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Streaming Event Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum StreamEvent {
|
||||
#[serde(rename = "message_start")]
|
||||
MessageStart { message: MessageStartInfo },
|
||||
|
||||
#[serde(rename = "content_block_start")]
|
||||
ContentBlockStart {
|
||||
index: usize,
|
||||
content_block: ContentBlockStartInfo,
|
||||
},
|
||||
|
||||
#[serde(rename = "content_block_delta")]
|
||||
ContentBlockDelta { index: usize, delta: ContentDelta },
|
||||
|
||||
#[serde(rename = "content_block_stop")]
|
||||
ContentBlockStop { index: usize },
|
||||
|
||||
#[serde(rename = "message_delta")]
|
||||
MessageDelta {
|
||||
delta: MessageDeltaInfo,
|
||||
usage: Option<UsageInfo>,
|
||||
},
|
||||
|
||||
#[serde(rename = "message_stop")]
|
||||
MessageStop,
|
||||
|
||||
#[serde(rename = "ping")]
|
||||
Ping,
|
||||
|
||||
#[serde(rename = "error")]
|
||||
Error { error: ApiError },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct MessageStartInfo {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub message_type: String,
|
||||
pub role: String,
|
||||
pub model: String,
|
||||
pub usage: Option<UsageInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ContentBlockStartInfo {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
|
||||
#[serde(rename = "tool_use")]
|
||||
ToolUse { id: String, name: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
#[serde(tag = "type")]
|
||||
pub enum ContentDelta {
|
||||
#[serde(rename = "text_delta")]
|
||||
TextDelta { text: String },
|
||||
|
||||
#[serde(rename = "input_json_delta")]
|
||||
InputJsonDelta { partial_json: String },
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct MessageDeltaInfo {
|
||||
pub stop_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ApiError {
|
||||
#[serde(rename = "type")]
|
||||
pub error_type: String,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Conversions
|
||||
// ============================================================================
|
||||
|
||||
impl From<&llm_core::Tool> for AnthropicTool {
|
||||
fn from(tool: &llm_core::Tool) -> Self {
|
||||
Self {
|
||||
name: tool.function.name.clone(),
|
||||
description: tool.function.description.clone(),
|
||||
input_schema: ToolInputSchema {
|
||||
schema_type: tool.function.parameters.param_type.clone(),
|
||||
properties: tool.function.parameters.properties.clone(),
|
||||
required: tool.function.parameters.required.clone(),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&llm_core::ChatMessage> for AnthropicMessage {
|
||||
fn from(msg: &llm_core::ChatMessage) -> Self {
|
||||
use llm_core::Role;
|
||||
|
||||
let role = match msg.role {
|
||||
Role::User | Role::System => "user",
|
||||
Role::Assistant => "assistant",
|
||||
Role::Tool => "user", // Tool results come as user messages in Anthropic
|
||||
};
|
||||
|
||||
// Handle tool results
|
||||
if msg.role == Role::Tool {
|
||||
if let (Some(tool_call_id), Some(content)) = (&msg.tool_call_id, &msg.content) {
|
||||
return Self {
|
||||
role: "user".to_string(),
|
||||
content: AnthropicContent::Blocks(vec![ContentBlock::ToolResult {
|
||||
tool_use_id: tool_call_id.clone(),
|
||||
content: content.clone(),
|
||||
is_error: None,
|
||||
}]),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Handle assistant messages with tool calls
|
||||
if msg.role == Role::Assistant {
|
||||
if let Some(tool_calls) = &msg.tool_calls {
|
||||
let mut blocks: Vec<ContentBlock> = Vec::new();
|
||||
|
||||
// Add text content if present
|
||||
if let Some(text) = &msg.content {
|
||||
if !text.is_empty() {
|
||||
blocks.push(ContentBlock::Text { text: text.clone() });
|
||||
}
|
||||
}
|
||||
|
||||
// Add tool use blocks
|
||||
for call in tool_calls {
|
||||
blocks.push(ContentBlock::ToolUse {
|
||||
id: call.id.clone(),
|
||||
name: call.function.name.clone(),
|
||||
input: call.function.arguments.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
return Self {
|
||||
role: "assistant".to_string(),
|
||||
content: AnthropicContent::Blocks(blocks),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
// Simple text message
|
||||
Self {
|
||||
role: role.to_string(),
|
||||
content: AnthropicContent::Text(msg.content.clone().unwrap_or_default()),
|
||||
}
|
||||
}
|
||||
}
|
||||
crates/llm/core/Cargo.toml (new file, +18)
@@ -0,0 +1,18 @@
[package]
name = "llm-core"
version = "0.1.0"
edition.workspace = true
license.workspace = true
description = "LLM provider abstraction layer for Owlen"

[dependencies]
async-trait = "0.1"
futures = "0.3"
rand = "0.8"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "2.0"
tokio = { version = "1.0", features = ["time"] }

[dev-dependencies]
tokio = { version = "1.0", features = ["macros", "rt"] }
crates/llm/core/examples/token_counting.rs (new file, +195)
@@ -0,0 +1,195 @@
|
||||
//! Token counting example
|
||||
//!
|
||||
//! This example demonstrates how to use the token counting utilities
|
||||
//! to manage LLM context windows.
|
||||
//!
|
||||
//! Run with: cargo run --example token_counting -p llm-core
|
||||
|
||||
use llm_core::{
|
||||
ChatMessage, ClaudeTokenCounter, ContextWindow, SimpleTokenCounter, TokenCounter,
|
||||
};
|
||||
|
||||
fn main() {
|
||||
println!("=== Token Counting Example ===\n");
|
||||
|
||||
// Example 1: Basic token counting with SimpleTokenCounter
|
||||
println!("1. Basic Token Counting");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
let simple_counter = SimpleTokenCounter::new(8192);
|
||||
let text = "The quick brown fox jumps over the lazy dog.";
|
||||
|
||||
let token_count = simple_counter.count(text);
|
||||
println!("Text: \"{}\"", text);
|
||||
println!("Estimated tokens: {}", token_count);
|
||||
println!("Max context: {}\n", simple_counter.max_context());
|
||||
|
||||
// Example 2: Counting tokens in chat messages
|
||||
println!("2. Counting Tokens in Chat Messages");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::system("You are a helpful assistant that provides concise answers."),
|
||||
ChatMessage::user("What is the capital of France?"),
|
||||
ChatMessage::assistant("The capital of France is Paris."),
|
||||
ChatMessage::user("What is its population?"),
|
||||
];
|
||||
|
||||
let total_tokens = simple_counter.count_messages(&messages);
|
||||
println!("Number of messages: {}", messages.len());
|
||||
println!("Total tokens (with overhead): {}\n", total_tokens);
|
||||
|
||||
// Example 3: Using ClaudeTokenCounter for Claude models
|
||||
println!("3. Claude-Specific Token Counting");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
let claude_counter = ClaudeTokenCounter::new();
|
||||
let claude_total = claude_counter.count_messages(&messages);
|
||||
|
||||
println!("Claude counter max context: {}", claude_counter.max_context());
|
||||
println!("Claude estimated tokens: {}\n", claude_total);
|
||||
|
||||
// Example 4: Context window management
|
||||
println!("4. Context Window Management");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
let mut context = ContextWindow::new(8192);
|
||||
println!("Created context window with max: {} tokens", context.max());
|
||||
|
||||
// Simulate adding messages
|
||||
let conversation = vec![
|
||||
ChatMessage::user("Tell me about Rust programming."),
|
||||
ChatMessage::assistant(
|
||||
"Rust is a systems programming language focused on safety, \
|
||||
speed, and concurrency. It prevents common bugs like null pointer \
|
||||
dereferences and data races through its ownership system.",
|
||||
),
|
||||
ChatMessage::user("What are its main features?"),
|
||||
ChatMessage::assistant(
|
||||
"Rust's main features include: 1) Memory safety without garbage collection, \
|
||||
2) Zero-cost abstractions, 3) Fearless concurrency, 4) Pattern matching, \
|
||||
5) Type inference, and 6) A powerful macro system.",
|
||||
),
|
||||
];
|
||||
|
||||
for (i, msg) in conversation.iter().enumerate() {
|
||||
let tokens = simple_counter.count_messages(&[msg.clone()]);
|
||||
context.add_tokens(tokens);
|
||||
|
||||
let role = msg.role.as_str();
|
||||
let preview = msg
|
||||
.content
|
||||
.as_ref()
|
||||
.map(|c| {
|
||||
if c.len() > 50 {
|
||||
format!("{}...", &c[..50])
|
||||
} else {
|
||||
c.clone()
|
||||
}
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
println!(
|
||||
"Message {}: [{}] \"{}\"",
|
||||
i + 1,
|
||||
role,
|
||||
preview
|
||||
);
|
||||
println!(" Added {} tokens", tokens);
|
||||
println!(" Total used: {} / {}", context.used(), context.max());
|
||||
println!(" Usage: {:.1}%", context.usage_percent() * 100.0);
|
||||
println!(" Progress: {}\n", context.progress_bar(30));
|
||||
}
|
||||
|
||||
// Example 5: Checking context limits
|
||||
println!("5. Checking Context Limits");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
if context.is_near_limit(0.8) {
|
||||
println!("Warning: Context is over 80% full!");
|
||||
} else {
|
||||
println!("Context usage is below 80%");
|
||||
}
|
||||
|
||||
let remaining = context.remaining();
|
||||
println!("Remaining tokens: {}", remaining);
|
||||
|
||||
let new_message_tokens = 500;
|
||||
if context.has_room_for(new_message_tokens) {
|
||||
println!(
|
||||
"Can fit a message of {} tokens",
|
||||
new_message_tokens
|
||||
);
|
||||
} else {
|
||||
println!(
|
||||
"Cannot fit a message of {} tokens - would need to compact or start new context",
|
||||
new_message_tokens
|
||||
);
|
||||
}
|
||||
|
||||
// Example 6: Different counter variants
|
||||
println!("\n6. Using Different Counter Variants");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
let counter_8k = SimpleTokenCounter::default_8k();
|
||||
let counter_32k = SimpleTokenCounter::with_32k();
|
||||
let counter_128k = SimpleTokenCounter::with_128k();
|
||||
|
||||
println!("8k context counter: {} tokens", counter_8k.max_context());
|
||||
println!("32k context counter: {} tokens", counter_32k.max_context());
|
||||
println!("128k context counter: {} tokens", counter_128k.max_context());
|
||||
|
||||
let haiku = ClaudeTokenCounter::haiku();
|
||||
let sonnet = ClaudeTokenCounter::sonnet();
|
||||
let opus = ClaudeTokenCounter::opus();
|
||||
|
||||
println!("\nClaude Haiku: {} tokens", haiku.max_context());
|
||||
println!("Claude Sonnet: {} tokens", sonnet.max_context());
|
||||
println!("Claude Opus: {} tokens", opus.max_context());
|
||||
|
||||
// Example 7: Managing context for a long conversation
|
||||
println!("\n7. Long Conversation Simulation");
|
||||
println!("{}", "-".repeat(50));
|
||||
|
||||
let mut long_context = ContextWindow::new(4096); // Smaller context for demo
|
||||
let counter = SimpleTokenCounter::new(4096);
|
||||
|
||||
let mut message_count = 0;
|
||||
let mut compaction_count = 0;
|
||||
|
||||
// Simulate 20 exchanges
|
||||
for i in 0..20 {
|
||||
let user_msg = ChatMessage::user(format!(
|
||||
"This is user message number {} asking a question.",
|
||||
i + 1
|
||||
));
|
||||
let assistant_msg = ChatMessage::assistant(format!(
|
||||
"This is assistant response number {} providing a detailed answer with multiple sentences to make it longer.",
|
||||
i + 1
|
||||
));
|
||||
|
||||
let tokens_needed = counter.count_messages(&[user_msg, assistant_msg]);
|
||||
|
||||
if !long_context.has_room_for(tokens_needed) {
|
||||
println!(
|
||||
"After {} messages, context is full ({}%). Compacting...",
|
||||
message_count,
|
||||
(long_context.usage_percent() * 100.0) as u32
|
||||
);
|
||||
// In a real scenario, we would compact the conversation
|
||||
// For now, just reset
|
||||
long_context.reset();
|
||||
compaction_count += 1;
|
||||
}
|
||||
|
||||
long_context.add_tokens(tokens_needed);
|
||||
message_count += 2;
|
||||
}
|
||||
|
||||
println!("Total messages: {}", message_count);
|
||||
println!("Compactions needed: {}", compaction_count);
|
||||
println!("Final context usage: {:.1}%", long_context.usage_percent() * 100.0);
|
||||
println!("Final progress: {}", long_context.progress_bar(40));
|
||||
|
||||
println!("\n=== Example Complete ===");
|
||||
}
|
||||
crates/llm/core/src/lib.rs (new file, +796)
@@ -0,0 +1,796 @@
|
||||
//! LLM Provider Abstraction Layer
|
||||
//!
|
||||
//! This crate defines the common types and traits for LLM provider integration.
|
||||
//! Providers (Ollama, Anthropic Claude, OpenAI) implement the `LlmProvider` trait
|
||||
//! to enable swapping providers at runtime.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::Stream;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::pin::Pin;
|
||||
use thiserror::Error;
|
||||
|
||||
// ============================================================================
|
||||
// Public Modules
|
||||
// ============================================================================
|
||||
|
||||
pub mod retry;
|
||||
pub mod tokens;
|
||||
|
||||
// Re-export token counting types for convenience
|
||||
pub use tokens::{ClaudeTokenCounter, ContextWindow, SimpleTokenCounter, TokenCounter};
|
||||
|
||||
// Re-export retry types for convenience
|
||||
pub use retry::{is_retryable_error, RetryConfig, RetryStrategy};
|
||||
|
||||
// ============================================================================
|
||||
// Error Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum LlmError {
|
||||
#[error("HTTP error: {0}")]
|
||||
Http(String),
|
||||
|
||||
#[error("JSON parsing error: {0}")]
|
||||
Json(String),
|
||||
|
||||
#[error("Authentication error: {0}")]
|
||||
Auth(String),
|
||||
|
||||
#[error("Rate limit exceeded: retry after {retry_after_secs:?} seconds")]
|
||||
RateLimit { retry_after_secs: Option<u64> },
|
||||
|
||||
#[error("API error: {message}")]
|
||||
Api { message: String, code: Option<String> },
|
||||
|
||||
#[error("Provider error: {0}")]
|
||||
Provider(String),
|
||||
|
||||
#[error("Stream error: {0}")]
|
||||
Stream(String),
|
||||
|
||||
#[error("Request timeout: {0}")]
|
||||
Timeout(String),
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Message Types
|
||||
// ============================================================================
|
||||
|
||||
/// Role of a message in the conversation
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum Role {
|
||||
System,
|
||||
User,
|
||||
Assistant,
|
||||
Tool,
|
||||
}
|
||||
|
||||
impl Role {
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Role::System => "system",
|
||||
Role::User => "user",
|
||||
Role::Assistant => "assistant",
|
||||
Role::Tool => "tool",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for Role {
|
||||
fn from(s: &str) -> Self {
|
||||
match s.to_lowercase().as_str() {
|
||||
"system" => Role::System,
|
||||
"user" => Role::User,
|
||||
"assistant" => Role::Assistant,
|
||||
"tool" => Role::Tool,
|
||||
_ => Role::User, // Default fallback
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A message in the conversation
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: Role,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
|
||||
/// Tool calls made by the assistant
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
|
||||
/// For tool role messages: the ID of the tool call this responds to
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
|
||||
/// For tool role messages: the name of the tool
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
/// Create a system message
|
||||
pub fn system(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: Role::System,
|
||||
content: Some(content.into()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a user message
|
||||
pub fn user(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: Role::User,
|
||||
content: Some(content.into()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an assistant message
|
||||
pub fn assistant(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: Role::Assistant,
|
||||
content: Some(content.into()),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an assistant message with tool calls (no text content)
|
||||
pub fn assistant_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
|
||||
Self {
|
||||
role: Role::Assistant,
|
||||
content: None,
|
||||
tool_calls: Some(tool_calls),
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a tool result message
|
||||
pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: Role::Tool,
|
||||
content: Some(content.into()),
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(tool_call_id.into()),
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tool Types
|
||||
// ============================================================================
|
||||
|
||||
/// A tool call requested by the LLM
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
/// Unique identifier for this tool call
|
||||
pub id: String,
|
||||
|
||||
/// The type of tool call (always "function" for now)
|
||||
#[serde(rename = "type", default = "default_function_type")]
|
||||
pub call_type: String,
|
||||
|
||||
/// The function being called
|
||||
pub function: FunctionCall,
|
||||
}
|
||||
|
||||
fn default_function_type() -> String {
|
||||
"function".to_string()
|
||||
}
|
||||
|
||||
/// Details of a function call
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct FunctionCall {
|
||||
/// Name of the function to call
|
||||
pub name: String,
|
||||
|
||||
/// Arguments as a JSON object
|
||||
pub arguments: Value,
|
||||
}
|
||||
|
||||
/// Definition of a tool available to the LLM
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Tool {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: String,
|
||||
|
||||
pub function: ToolFunction,
|
||||
}
|
||||
|
||||
impl Tool {
|
||||
/// Create a new function tool
|
||||
pub fn function(
|
||||
name: impl Into<String>,
|
||||
description: impl Into<String>,
|
||||
parameters: ToolParameters,
|
||||
) -> Self {
|
||||
Self {
|
||||
tool_type: "function".to_string(),
|
||||
function: ToolFunction {
|
||||
name: name.into(),
|
||||
description: description.into(),
|
||||
parameters,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Function definition within a tool
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolFunction {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: ToolParameters,
|
||||
}
|
||||
|
||||
/// Parameters schema for a function
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolParameters {
|
||||
#[serde(rename = "type")]
|
||||
pub param_type: String,
|
||||
|
||||
/// JSON Schema properties object
|
||||
pub properties: Value,
|
||||
|
||||
/// Required parameter names
|
||||
pub required: Vec<String>,
|
||||
}
|
||||
|
||||
impl ToolParameters {
|
||||
/// Create an object parameter schema
|
||||
pub fn object(properties: Value, required: Vec<String>) -> Self {
|
||||
Self {
|
||||
param_type: "object".to_string(),
|
||||
properties,
|
||||
required,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Streaming Response Types
|
||||
// ============================================================================
|
||||
|
||||
/// A chunk of a streaming response
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamChunk {
|
||||
/// Incremental text content
|
||||
pub content: Option<String>,
|
||||
|
||||
/// Tool calls (may be partial/streaming)
|
||||
pub tool_calls: Option<Vec<ToolCallDelta>>,
|
||||
|
||||
/// Whether this is the final chunk
|
||||
pub done: bool,
|
||||
|
||||
/// Usage statistics (typically only in final chunk)
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
/// Partial tool call for streaming
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ToolCallDelta {
|
||||
/// Index of this tool call in the array
|
||||
pub index: usize,
|
||||
|
||||
/// Tool call ID (may only be present in first delta)
|
||||
pub id: Option<String>,
|
||||
|
||||
/// Function name (may only be present in first delta)
|
||||
pub function_name: Option<String>,
|
||||
|
||||
/// Incremental arguments string
|
||||
pub arguments_delta: Option<String>,
|
||||
}
|
||||
|
||||
/// Token usage statistics
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Usage {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Provider Configuration
|
||||
// ============================================================================
|
||||
|
||||
/// Options for a chat request
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ChatOptions {
|
||||
/// Model to use
|
||||
pub model: String,
|
||||
|
||||
/// Temperature (0.0 - 2.0)
|
||||
pub temperature: Option<f32>,
|
||||
|
||||
/// Maximum tokens to generate
|
||||
pub max_tokens: Option<u32>,
|
||||
|
||||
/// Top-p sampling
|
||||
pub top_p: Option<f32>,
|
||||
|
||||
/// Stop sequences
|
||||
pub stop: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl ChatOptions {
|
||||
pub fn new(model: impl Into<String>) -> Self {
|
||||
Self {
|
||||
model: model.into(),
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_temperature(mut self, temp: f32) -> Self {
|
||||
self.temperature = Some(temp);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_max_tokens(mut self, max: u32) -> Self {
|
||||
self.max_tokens = Some(max);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Provider Trait
|
||||
// ============================================================================
|
||||
|
||||
/// A boxed stream of chunks
|
||||
pub type ChunkStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, LlmError>> + Send>>;
|
||||
|
||||
/// The main trait that all LLM providers must implement
|
||||
#[async_trait]
|
||||
pub trait LlmProvider: Send + Sync {
|
||||
/// Get the provider name (e.g., "ollama", "anthropic", "openai")
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Get the current model name
|
||||
fn model(&self) -> &str;
|
||||
|
||||
/// Send a chat request and receive a streaming response
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `messages` - The conversation history
|
||||
/// * `options` - Request options (model, temperature, etc.)
|
||||
/// * `tools` - Optional list of tools the model can use
|
||||
///
|
||||
/// # Returns
|
||||
/// A stream of response chunks
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
options: &ChatOptions,
|
||||
tools: Option<&[Tool]>,
|
||||
) -> Result<ChunkStream, LlmError>;
|
||||
|
||||
/// Send a chat request and receive a complete response (non-streaming)
|
||||
///
|
||||
/// Default implementation collects the stream, but providers may override
|
||||
/// for efficiency.
|
||||
async fn chat(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
options: &ChatOptions,
|
||||
tools: Option<&[Tool]>,
|
||||
) -> Result<ChatResponse, LlmError> {
|
||||
use futures::StreamExt;
|
||||
|
||||
let mut stream = self.chat_stream(messages, options, tools).await?;
|
||||
let mut content = String::new();
|
||||
let mut tool_calls: Vec<PartialToolCall> = Vec::new();
|
||||
let mut usage = None;
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
let chunk = chunk?;
|
||||
|
||||
if let Some(text) = chunk.content {
|
||||
content.push_str(&text);
|
||||
}
|
||||
|
||||
if let Some(deltas) = chunk.tool_calls {
|
||||
for delta in deltas {
|
||||
// Grow the tool_calls vec if needed
|
||||
while tool_calls.len() <= delta.index {
|
||||
tool_calls.push(PartialToolCall::default());
|
||||
}
|
||||
|
||||
let partial = &mut tool_calls[delta.index];
|
||||
if let Some(id) = delta.id {
|
||||
partial.id = Some(id);
|
||||
}
|
||||
if let Some(name) = delta.function_name {
|
||||
partial.function_name = Some(name);
|
||||
}
|
||||
if let Some(args) = delta.arguments_delta {
|
||||
partial.arguments.push_str(&args);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if chunk.usage.is_some() {
|
||||
usage = chunk.usage;
|
||||
}
|
||||
}
|
||||
|
||||
// Convert partial tool calls to complete tool calls
|
||||
let final_tool_calls: Vec<ToolCall> = tool_calls
|
||||
.into_iter()
|
||||
.filter_map(|p| p.try_into_tool_call())
|
||||
.collect();
|
||||
|
||||
Ok(ChatResponse {
|
||||
content: if content.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(content)
|
||||
},
|
||||
tool_calls: if final_tool_calls.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(final_tool_calls)
|
||||
},
|
||||
usage,
|
||||
})
|
||||
}
|
||||
}
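// Illustrative sketch of consuming `chat_stream` from the caller's side:
// drain the chunk stream and collect only the text deltas, the same pattern
// the default `chat` implementation above uses. Assumes no tools are offered.
#[allow(dead_code)]
async fn example_collect_stream_text(
    provider: &dyn LlmProvider,
    messages: &[ChatMessage],
    options: &ChatOptions,
) -> Result<String, LlmError> {
    use futures::StreamExt;

    let mut stream = provider.chat_stream(messages, options, None).await?;
    let mut text = String::new();
    while let Some(chunk) = stream.next().await {
        let chunk = chunk?;
        if let Some(delta) = chunk.content {
            text.push_str(&delta);
        }
        if chunk.done {
            break;
        }
    }
    Ok(text)
}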
|
||||
|
||||
/// A complete chat response (non-streaming)
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ChatResponse {
|
||||
pub content: Option<String>,
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
pub usage: Option<Usage>,
|
||||
}
|
||||
|
||||
/// Helper for accumulating streaming tool calls
|
||||
#[derive(Default)]
|
||||
struct PartialToolCall {
|
||||
id: Option<String>,
|
||||
function_name: Option<String>,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
impl PartialToolCall {
|
||||
fn try_into_tool_call(self) -> Option<ToolCall> {
|
||||
let id = self.id?;
|
||||
let name = self.function_name?;
|
||||
let arguments: Value = serde_json::from_str(&self.arguments).ok()?;
|
||||
|
||||
Some(ToolCall {
|
||||
id,
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall { name, arguments },
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Authentication
|
||||
// ============================================================================
|
||||
|
||||
/// Authentication method for LLM providers
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum AuthMethod {
|
||||
/// No authentication (for local providers like Ollama)
|
||||
None,
|
||||
|
||||
/// API key authentication
|
||||
ApiKey(String),
|
||||
|
||||
/// OAuth access token (from login flow)
|
||||
OAuth {
|
||||
access_token: String,
|
||||
refresh_token: Option<String>,
|
||||
expires_at: Option<u64>,
|
||||
},
|
||||
}
|
||||
|
||||
impl AuthMethod {
|
||||
/// Create API key auth
|
||||
pub fn api_key(key: impl Into<String>) -> Self {
|
||||
Self::ApiKey(key.into())
|
||||
}
|
||||
|
||||
/// Create OAuth auth from tokens
|
||||
pub fn oauth(access_token: impl Into<String>) -> Self {
|
||||
Self::OAuth {
|
||||
access_token: access_token.into(),
|
||||
refresh_token: None,
|
||||
expires_at: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create OAuth auth with refresh token
|
||||
pub fn oauth_with_refresh(
|
||||
access_token: impl Into<String>,
|
||||
refresh_token: impl Into<String>,
|
||||
expires_at: Option<u64>,
|
||||
) -> Self {
|
||||
Self::OAuth {
|
||||
access_token: access_token.into(),
|
||||
refresh_token: Some(refresh_token.into()),
|
||||
expires_at,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the bearer token for Authorization header
|
||||
pub fn bearer_token(&self) -> Option<&str> {
|
||||
match self {
|
||||
Self::None => None,
|
||||
Self::ApiKey(key) => Some(key),
|
||||
Self::OAuth { access_token, .. } => Some(access_token),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if token might need refresh
|
||||
pub fn needs_refresh(&self) -> bool {
|
||||
match self {
|
||||
Self::OAuth {
|
||||
expires_at: Some(exp),
|
||||
refresh_token: Some(_),
|
||||
..
|
||||
} => {
|
||||
let now = std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs())
|
||||
.unwrap_or(0);
|
||||
// Refresh if expiring within 5 minutes
|
||||
*exp < now + 300
|
||||
}
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
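// Illustrative sketch of caller-side use of `AuthMethod`: build an
// Authorization header value unless the token is about to expire. The Bearer
// scheme here is an assumption; individual providers may require a different
// header (for example an `x-api-key` header) instead.
#[allow(dead_code)]
fn example_authorization_header(auth: &AuthMethod) -> Option<String> {
    if auth.needs_refresh() {
        // A real caller would refresh via an `OAuthProvider` implementation first.
        return None;
    }
    auth.bearer_token().map(|token| format!("Bearer {}", token))
}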
|
||||
|
||||
/// Device code response for OAuth device flow
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct DeviceCodeResponse {
|
||||
/// Code the user enters on the verification page
|
||||
pub user_code: String,
|
||||
|
||||
/// URL the user visits to authorize
|
||||
pub verification_uri: String,
|
||||
|
||||
/// Full URL with code pre-filled (if supported)
|
||||
pub verification_uri_complete: Option<String>,
|
||||
|
||||
/// Device code for polling (internal use)
|
||||
pub device_code: String,
|
||||
|
||||
/// How often to poll (in seconds)
|
||||
pub interval: u64,
|
||||
|
||||
    /// Seconds until the device and user codes expire
|
||||
pub expires_in: u64,
|
||||
}
|
||||
|
||||
/// Result of polling for device authorization
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum DeviceAuthResult {
|
||||
/// Still waiting for user to authorize
|
||||
Pending,
|
||||
|
||||
/// User authorized, here are the tokens
|
||||
Success {
|
||||
access_token: String,
|
||||
refresh_token: Option<String>,
|
||||
expires_in: Option<u64>,
|
||||
},
|
||||
|
||||
/// User denied authorization
|
||||
Denied,
|
||||
|
||||
/// Code expired
|
||||
Expired,
|
||||
}
|
||||
|
||||
/// Trait for providers that support OAuth device flow
|
||||
#[async_trait]
|
||||
pub trait OAuthProvider {
|
||||
/// Start the device authorization flow
|
||||
async fn start_device_auth(&self) -> Result<DeviceCodeResponse, LlmError>;
|
||||
|
||||
/// Poll for the authorization result
|
||||
async fn poll_device_auth(&self, device_code: &str) -> Result<DeviceAuthResult, LlmError>;
|
||||
|
||||
/// Refresh an access token using a refresh token
|
||||
async fn refresh_token(&self, refresh_token: &str) -> Result<AuthMethod, LlmError>;
|
||||
}
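// Illustrative sketch of driving the device flow end to end with the trait
// above. A real client would surface `user_code`/`verification_uri` in the UI
// and persist the resulting tokens (e.g. as `StoredCredentials`) rather than
// just returning them.
#[allow(dead_code)]
async fn example_device_login(provider: &dyn OAuthProvider) -> Result<AuthMethod, LlmError> {
    let device = provider.start_device_auth().await?;
    println!(
        "Visit {} and enter code {}",
        device.verification_uri, device.user_code
    );

    loop {
        // Respect the server-suggested polling interval.
        tokio::time::sleep(std::time::Duration::from_secs(device.interval)).await;
        match provider.poll_device_auth(&device.device_code).await? {
            DeviceAuthResult::Pending => continue,
            DeviceAuthResult::Success {
                access_token,
                refresh_token,
                expires_in,
            } => {
                let now = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .map(|d| d.as_secs())
                    .unwrap_or(0);
                let expires_at = expires_in.map(|secs| now + secs);
                return Ok(match refresh_token {
                    Some(rt) => AuthMethod::oauth_with_refresh(access_token, rt, expires_at),
                    None => AuthMethod::oauth(access_token),
                });
            }
            DeviceAuthResult::Denied => {
                return Err(LlmError::Auth("authorization denied by user".to_string()))
            }
            DeviceAuthResult::Expired => {
                return Err(LlmError::Auth("device code expired".to_string()))
            }
        }
    }
}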
|
||||
|
||||
/// Stored credentials for a provider
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StoredCredentials {
|
||||
pub provider: String,
|
||||
pub access_token: String,
|
||||
pub refresh_token: Option<String>,
|
||||
pub expires_at: Option<u64>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Provider Status & Info
|
||||
// ============================================================================
|
||||
|
||||
/// Status information for a provider connection
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ProviderStatus {
|
||||
/// Provider name
|
||||
pub provider: String,
|
||||
|
||||
/// Whether the connection is authenticated
|
||||
pub authenticated: bool,
|
||||
|
||||
/// Current user/account info if authenticated
|
||||
pub account: Option<AccountInfo>,
|
||||
|
||||
/// Current model being used
|
||||
pub model: String,
|
||||
|
||||
/// API endpoint URL
|
||||
pub endpoint: String,
|
||||
|
||||
/// Whether the provider is reachable
|
||||
pub reachable: bool,
|
||||
|
||||
/// Any status message or error
|
||||
pub message: Option<String>,
|
||||
}
|
||||
|
||||
/// Account/user information from the provider
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AccountInfo {
|
||||
/// Account/user ID
|
||||
pub id: Option<String>,
|
||||
|
||||
/// Display name or email
|
||||
pub name: Option<String>,
|
||||
|
||||
/// Account email
|
||||
pub email: Option<String>,
|
||||
|
||||
/// Account type (free, pro, team, enterprise)
|
||||
pub account_type: Option<String>,
|
||||
|
||||
/// Organization name if applicable
|
||||
pub organization: Option<String>,
|
||||
}
|
||||
|
||||
/// Usage statistics from the provider
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UsageStats {
|
||||
/// Total tokens used in current period
|
||||
pub tokens_used: Option<u64>,
|
||||
|
||||
/// Token limit for current period (if applicable)
|
||||
pub token_limit: Option<u64>,
|
||||
|
||||
/// Number of requests made
|
||||
pub requests_made: Option<u64>,
|
||||
|
||||
/// Request limit (if applicable)
|
||||
pub request_limit: Option<u64>,
|
||||
|
||||
/// Cost incurred (if available)
|
||||
pub cost_usd: Option<f64>,
|
||||
|
||||
/// Period start timestamp
|
||||
pub period_start: Option<u64>,
|
||||
|
||||
/// Period end timestamp
|
||||
pub period_end: Option<u64>,
|
||||
}
|
||||
|
||||
/// Available model information
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ModelInfo {
|
||||
/// Model ID/name
|
||||
pub id: String,
|
||||
|
||||
/// Human-readable display name
|
||||
pub display_name: Option<String>,
|
||||
|
||||
/// Model description
|
||||
pub description: Option<String>,
|
||||
|
||||
/// Context window size (tokens)
|
||||
pub context_window: Option<u32>,
|
||||
|
||||
/// Max output tokens
|
||||
pub max_output_tokens: Option<u32>,
|
||||
|
||||
/// Whether the model supports tool use
|
||||
pub supports_tools: bool,
|
||||
|
||||
/// Whether the model supports vision/images
|
||||
pub supports_vision: bool,
|
||||
|
||||
/// Input token price per 1M tokens (USD)
|
||||
pub input_price_per_mtok: Option<f64>,
|
||||
|
||||
/// Output token price per 1M tokens (USD)
|
||||
pub output_price_per_mtok: Option<f64>,
|
||||
}
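// Illustrative sketch: estimating request cost from the per-million-token
// prices above together with `Usage` from a completed response. Returns `None`
// when pricing is unknown (e.g. local Ollama models).
#[allow(dead_code)]
fn example_estimate_cost_usd(model: &ModelInfo, usage: &Usage) -> Option<f64> {
    let input_price = model.input_price_per_mtok?;
    let output_price = model.output_price_per_mtok?;
    Some(
        usage.prompt_tokens as f64 / 1_000_000.0 * input_price
            + usage.completion_tokens as f64 / 1_000_000.0 * output_price,
    )
}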
|
||||
|
||||
/// Trait for providers that support status/info queries
|
||||
#[async_trait]
|
||||
pub trait ProviderInfo {
|
||||
/// Get the current connection status
|
||||
async fn status(&self) -> Result<ProviderStatus, LlmError>;
|
||||
|
||||
/// Get account information (if authenticated)
|
||||
async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError>;
|
||||
|
||||
/// Get usage statistics (if available)
|
||||
async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError>;
|
||||
|
||||
/// List available models
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError>;
|
||||
|
||||
/// Check if a specific model is available
|
||||
async fn model_info(&self, model_id: &str) -> Result<Option<ModelInfo>, LlmError> {
|
||||
let models = self.list_models().await?;
|
||||
Ok(models.into_iter().find(|m| m.id == model_id))
|
||||
}
|
||||
}
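// Illustrative sketch: condensing `ProviderInfo::status` into the kind of
// one-line summary a status bar might render. The formatting is an assumption,
// not part of the trait contract.
#[allow(dead_code)]
async fn example_status_line(provider: &dyn ProviderInfo) -> Result<String, LlmError> {
    let status = provider.status().await?;
    Ok(format!(
        "{} | {} | {} | {}",
        status.provider,
        status.model,
        if status.authenticated { "authenticated" } else { "unauthenticated" },
        if status.reachable { "online" } else { "offline" },
    ))
}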
|
||||
|
||||
// ============================================================================
|
||||
// Provider Factory
|
||||
// ============================================================================
|
||||
|
||||
/// Supported LLM providers
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ProviderType {
|
||||
Ollama,
|
||||
Anthropic,
|
||||
OpenAI,
|
||||
}
|
||||
|
||||
impl ProviderType {
|
||||
pub fn from_str(s: &str) -> Option<Self> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"ollama" => Some(Self::Ollama),
|
||||
"anthropic" | "claude" => Some(Self::Anthropic),
|
||||
"openai" | "gpt" => Some(Self::OpenAI),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn as_str(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Ollama => "ollama",
|
||||
Self::Anthropic => "anthropic",
|
||||
Self::OpenAI => "openai",
|
||||
}
|
||||
}
|
||||
|
||||
/// Default model for this provider
|
||||
pub fn default_model(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Ollama => "qwen3:8b",
|
||||
Self::Anthropic => "claude-sonnet-4-20250514",
|
||||
Self::OpenAI => "gpt-4o",
|
||||
}
|
||||
}
|
||||
}
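// Illustrative sketch: resolving a provider name from user configuration and
// falling back to that provider's default model when none is configured.
#[allow(dead_code)]
fn example_resolve_model(provider_name: &str, configured_model: Option<&str>) -> Option<String> {
    let provider = ProviderType::from_str(provider_name)?;
    Some(
        configured_model
            .unwrap_or_else(|| provider.default_model())
            .to_string(),
    )
}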
|
||||
|
||||
impl std::fmt::Display for ProviderType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.as_str())
|
||||
}
|
||||
}
|
||||
386
crates/llm/core/src/retry.rs
Normal file
@@ -0,0 +1,386 @@
|
||||
//! Error recovery and retry logic for LLM operations
|
||||
//!
|
||||
//! This module provides configurable retry strategies with exponential backoff
|
||||
//! for handling transient failures when communicating with LLM providers.
|
||||
|
||||
use crate::LlmError;
|
||||
use rand::Rng;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Configuration for retry behavior
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetryConfig {
|
||||
/// Maximum number of retry attempts
|
||||
pub max_retries: u32,
|
||||
/// Initial delay before first retry (in milliseconds)
|
||||
pub initial_delay_ms: u64,
|
||||
/// Maximum delay between retries (in milliseconds)
|
||||
pub max_delay_ms: u64,
|
||||
/// Multiplier for exponential backoff
|
||||
pub backoff_multiplier: f32,
|
||||
}
|
||||
|
||||
impl Default for RetryConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_retries: 3,
|
||||
initial_delay_ms: 1000,
|
||||
max_delay_ms: 30000,
|
||||
backoff_multiplier: 2.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl RetryConfig {
|
||||
/// Create a new retry configuration with custom values
|
||||
pub fn new(
|
||||
max_retries: u32,
|
||||
initial_delay_ms: u64,
|
||||
max_delay_ms: u64,
|
||||
backoff_multiplier: f32,
|
||||
) -> Self {
|
||||
Self {
|
||||
max_retries,
|
||||
initial_delay_ms,
|
||||
max_delay_ms,
|
||||
backoff_multiplier,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a configuration with no retries
|
||||
pub fn no_retry() -> Self {
|
||||
Self {
|
||||
max_retries: 0,
|
||||
initial_delay_ms: 0,
|
||||
max_delay_ms: 0,
|
||||
backoff_multiplier: 1.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a configuration with aggressive retries for rate-limited scenarios
|
||||
pub fn aggressive() -> Self {
|
||||
Self {
|
||||
max_retries: 5,
|
||||
initial_delay_ms: 2000,
|
||||
max_delay_ms: 60000,
|
||||
backoff_multiplier: 2.5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Determines whether an error is retryable
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `error` - The error to check
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if the error is transient and the operation should be retried,
|
||||
/// `false` if the error is permanent and retrying won't help
|
||||
pub fn is_retryable_error(error: &LlmError) -> bool {
|
||||
match error {
|
||||
// Always retry rate limits
|
||||
LlmError::RateLimit { .. } => true,
|
||||
|
||||
// Always retry timeouts
|
||||
LlmError::Timeout(_) => true,
|
||||
|
||||
// Retry HTTP errors that are server-side (5xx)
|
||||
LlmError::Http(msg) => {
|
||||
// Check if the error message contains a 5xx status code
|
||||
msg.contains("500")
|
||||
|| msg.contains("502")
|
||||
|| msg.contains("503")
|
||||
|| msg.contains("504")
|
||||
|| msg.contains("Internal Server Error")
|
||||
|| msg.contains("Bad Gateway")
|
||||
|| msg.contains("Service Unavailable")
|
||||
|| msg.contains("Gateway Timeout")
|
||||
}
|
||||
|
||||
// Don't retry authentication errors - they need user intervention
|
||||
LlmError::Auth(_) => false,
|
||||
|
||||
// Don't retry JSON parsing errors - the data is malformed
|
||||
LlmError::Json(_) => false,
|
||||
|
||||
// Don't retry API errors - these are typically client-side issues
|
||||
LlmError::Api { .. } => false,
|
||||
|
||||
// Provider errors might be transient, but we conservatively don't retry
|
||||
LlmError::Provider(_) => false,
|
||||
|
||||
// Stream errors are typically not retryable
|
||||
LlmError::Stream(_) => false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Strategy for retrying failed operations with exponential backoff
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RetryStrategy {
|
||||
config: RetryConfig,
|
||||
}
|
||||
|
||||
impl RetryStrategy {
|
||||
/// Create a new retry strategy with the given configuration
|
||||
pub fn new(config: RetryConfig) -> Self {
|
||||
Self { config }
|
||||
}
|
||||
|
||||
/// Create a retry strategy with default configuration
|
||||
pub fn default_config() -> Self {
|
||||
Self::new(RetryConfig::default())
|
||||
}
|
||||
|
||||
/// Execute an async operation with retries
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `operation` - A function that returns a Future producing a Result
|
||||
///
|
||||
/// # Returns
|
||||
/// The result of the operation, or the last error if all retries fail
|
||||
///
|
||||
/// # Example
|
||||
/// ```ignore
|
||||
/// let strategy = RetryStrategy::default_config();
|
||||
/// let result = strategy.execute(|| async {
|
||||
/// // Your LLM API call here
|
||||
/// llm_client.chat(&messages, &options, None).await
|
||||
/// }).await?;
|
||||
/// ```
|
||||
pub async fn execute<F, T, Fut>(&self, operation: F) -> Result<T, LlmError>
|
||||
where
|
||||
F: Fn() -> Fut,
|
||||
Fut: std::future::Future<Output = Result<T, LlmError>>,
|
||||
{
|
||||
let mut attempt = 0;
|
||||
|
||||
loop {
|
||||
// Try the operation
|
||||
match operation().await {
|
||||
Ok(result) => return Ok(result),
|
||||
Err(err) => {
|
||||
// Check if we should retry
|
||||
if !is_retryable_error(&err) {
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
attempt += 1;
|
||||
|
||||
// Check if we've exhausted retries
|
||||
if attempt > self.config.max_retries {
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
// Calculate delay with exponential backoff and jitter
|
||||
let delay = self.delay_for_attempt(attempt);
|
||||
|
||||
// Log retry attempt (in a real implementation, you might use tracing)
|
||||
eprintln!(
|
||||
"Retry attempt {}/{} after {:?}",
|
||||
attempt, self.config.max_retries, delay
|
||||
);
|
||||
|
||||
// Sleep before next attempt
|
||||
tokio::time::sleep(delay).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate the delay for a given attempt number with jitter
|
||||
///
|
||||
/// Uses exponential backoff: delay = initial_delay * (backoff_multiplier ^ (attempt - 1))
|
||||
/// Adds random jitter of ±10% to prevent thundering herd problems
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `attempt` - The attempt number (1-indexed)
|
||||
///
|
||||
/// # Returns
|
||||
/// The delay duration to wait before the next retry
|
||||
fn delay_for_attempt(&self, attempt: u32) -> Duration {
|
||||
// Calculate base delay with exponential backoff
|
||||
let base_delay_ms = self.config.initial_delay_ms as f64
|
||||
* self.config.backoff_multiplier.powi((attempt - 1) as i32) as f64;
|
||||
|
||||
// Cap at max_delay_ms
|
||||
let capped_delay_ms = base_delay_ms.min(self.config.max_delay_ms as f64);
|
||||
|
||||
// Add jitter: ±10%
|
||||
let mut rng = rand::thread_rng();
|
||||
let jitter_factor = rng.gen_range(0.9..=1.1);
|
||||
let final_delay_ms = capped_delay_ms * jitter_factor;
|
||||
|
||||
Duration::from_millis(final_delay_ms as u64)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::atomic::{AtomicU32, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[test]
|
||||
fn test_default_retry_config() {
|
||||
let config = RetryConfig::default();
|
||||
assert_eq!(config.max_retries, 3);
|
||||
assert_eq!(config.initial_delay_ms, 1000);
|
||||
assert_eq!(config.max_delay_ms, 30000);
|
||||
assert_eq!(config.backoff_multiplier, 2.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_no_retry_config() {
|
||||
let config = RetryConfig::no_retry();
|
||||
assert_eq!(config.max_retries, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_retryable_error() {
|
||||
// Retryable errors
|
||||
assert!(is_retryable_error(&LlmError::RateLimit {
|
||||
retry_after_secs: Some(60)
|
||||
}));
|
||||
assert!(is_retryable_error(&LlmError::Timeout(
|
||||
"Request timed out".to_string()
|
||||
)));
|
||||
assert!(is_retryable_error(&LlmError::Http(
|
||||
"500 Internal Server Error".to_string()
|
||||
)));
|
||||
assert!(is_retryable_error(&LlmError::Http(
|
||||
"503 Service Unavailable".to_string()
|
||||
)));
|
||||
|
||||
// Non-retryable errors
|
||||
assert!(!is_retryable_error(&LlmError::Auth(
|
||||
"Invalid API key".to_string()
|
||||
)));
|
||||
assert!(!is_retryable_error(&LlmError::Json(
|
||||
"Invalid JSON".to_string()
|
||||
)));
|
||||
assert!(!is_retryable_error(&LlmError::Api {
|
||||
message: "Invalid request".to_string(),
|
||||
code: Some("400".to_string())
|
||||
}));
|
||||
assert!(!is_retryable_error(&LlmError::Http(
|
||||
"400 Bad Request".to_string()
|
||||
)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delay_calculation() {
|
||||
let config = RetryConfig::default();
|
||||
let strategy = RetryStrategy::new(config);
|
||||
|
||||
// Test that delays increase exponentially
|
||||
let delay1 = strategy.delay_for_attempt(1);
|
||||
let delay2 = strategy.delay_for_attempt(2);
|
||||
let delay3 = strategy.delay_for_attempt(3);
|
||||
|
||||
// Base delays should be around 1000ms, 2000ms, 4000ms (with jitter)
|
||||
assert!(delay1.as_millis() >= 900 && delay1.as_millis() <= 1100);
|
||||
assert!(delay2.as_millis() >= 1800 && delay2.as_millis() <= 2200);
|
||||
assert!(delay3.as_millis() >= 3600 && delay3.as_millis() <= 4400);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delay_max_cap() {
|
||||
let config = RetryConfig {
|
||||
max_retries: 10,
|
||||
initial_delay_ms: 1000,
|
||||
max_delay_ms: 5000,
|
||||
backoff_multiplier: 2.0,
|
||||
};
|
||||
let strategy = RetryStrategy::new(config);
|
||||
|
||||
// Even with high attempt numbers, delay should be capped
|
||||
let delay = strategy.delay_for_attempt(10);
|
||||
assert!(delay.as_millis() <= 5500); // max + jitter
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retry_success_on_first_attempt() {
|
||||
let strategy = RetryStrategy::default_config();
|
||||
let call_count = Arc::new(AtomicU32::new(0));
|
||||
let count_clone = call_count.clone();
|
||||
|
||||
let result = strategy
|
||||
.execute(|| {
|
||||
let count = count_clone.clone();
|
||||
async move {
|
||||
count.fetch_add(1, Ordering::SeqCst);
|
||||
Ok::<_, LlmError>(42)
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
assert_eq!(result.unwrap(), 42);
|
||||
assert_eq!(call_count.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retry_success_after_retries() {
|
||||
let config = RetryConfig::new(3, 10, 100, 2.0); // Fast retries for testing
|
||||
let strategy = RetryStrategy::new(config);
|
||||
let call_count = Arc::new(AtomicU32::new(0));
|
||||
let count_clone = call_count.clone();
|
||||
|
||||
let result = strategy
|
||||
.execute(|| {
|
||||
let count = count_clone.clone();
|
||||
async move {
|
||||
let current = count.fetch_add(1, Ordering::SeqCst) + 1;
|
||||
if current < 3 {
|
||||
Err(LlmError::Timeout("Timeout".to_string()))
|
||||
} else {
|
||||
Ok(42)
|
||||
}
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
assert_eq!(result.unwrap(), 42);
|
||||
assert_eq!(call_count.load(Ordering::SeqCst), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_retry_exhausted() {
|
||||
let config = RetryConfig::new(2, 10, 100, 2.0); // Fast retries for testing
|
||||
let strategy = RetryStrategy::new(config);
|
||||
let call_count = Arc::new(AtomicU32::new(0));
|
||||
let count_clone = call_count.clone();
|
||||
|
||||
let result = strategy
|
||||
.execute(|| {
|
||||
let count = count_clone.clone();
|
||||
async move {
|
||||
count.fetch_add(1, Ordering::SeqCst);
|
||||
Err::<(), _>(LlmError::Timeout("Always fails".to_string()))
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert_eq!(call_count.load(Ordering::SeqCst), 3); // Initial attempt + 2 retries
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_non_retryable_error() {
|
||||
let strategy = RetryStrategy::default_config();
|
||||
let call_count = Arc::new(AtomicU32::new(0));
|
||||
let count_clone = call_count.clone();
|
||||
|
||||
let result = strategy
|
||||
.execute(|| {
|
||||
let count = count_clone.clone();
|
||||
async move {
|
||||
count.fetch_add(1, Ordering::SeqCst);
|
||||
Err::<(), _>(LlmError::Auth("Invalid API key".to_string()))
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert_eq!(call_count.load(Ordering::SeqCst), 1); // Should not retry
|
||||
}
|
||||
}
|
||||
607
crates/llm/core/src/tokens.rs
Normal file
@@ -0,0 +1,607 @@
|
||||
//! Token counting utilities for LLM context management
|
||||
//!
|
||||
//! This module provides token counting abstractions and implementations for
|
||||
//! managing LLM context windows. Token counters estimate token usage without
|
||||
//! requiring external tokenization libraries, using heuristic-based approaches.
|
||||
|
||||
use crate::ChatMessage;
|
||||
|
||||
// ============================================================================
|
||||
// TokenCounter Trait
|
||||
// ============================================================================
|
||||
|
||||
/// Trait for counting tokens in text and chat messages
|
||||
///
|
||||
/// Implementations provide model-specific token counting logic to help
|
||||
/// manage context windows and estimate API costs.
|
||||
pub trait TokenCounter: Send + Sync {
|
||||
/// Count tokens in a string
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `text` - The text to count tokens for
|
||||
///
|
||||
/// # Returns
|
||||
/// Estimated number of tokens
|
||||
fn count(&self, text: &str) -> usize;
|
||||
|
||||
/// Count tokens in chat messages
|
||||
///
|
||||
/// This accounts for both the message content and the overhead
|
||||
/// from the chat message structure (roles, delimiters, etc.).
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `messages` - The messages to count tokens for
|
||||
///
|
||||
/// # Returns
|
||||
/// Estimated total tokens including message structure overhead
|
||||
fn count_messages(&self, messages: &[ChatMessage]) -> usize;
|
||||
|
||||
/// Get the model's max context window size
|
||||
///
|
||||
/// # Returns
|
||||
/// Maximum number of tokens the model can handle
|
||||
fn max_context(&self) -> usize;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SimpleTokenCounter
|
||||
// ============================================================================
|
||||
|
||||
/// A basic token counter using simple heuristics
|
||||
///
|
||||
/// This counter uses the rule of thumb that English text averages about
|
||||
/// 4 characters per token. It adds overhead for message structure.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use llm_core::tokens::{TokenCounter, SimpleTokenCounter};
|
||||
/// use llm_core::ChatMessage;
|
||||
///
|
||||
/// let counter = SimpleTokenCounter::new(8192);
|
||||
/// let text = "Hello, world!";
|
||||
/// let tokens = counter.count(text);
|
||||
/// assert!(tokens > 0);
|
||||
///
|
||||
/// let messages = vec![
|
||||
/// ChatMessage::user("What is the weather?"),
|
||||
/// ChatMessage::assistant("I don't have access to weather data."),
|
||||
/// ];
|
||||
/// let total = counter.count_messages(&messages);
|
||||
/// assert!(total > 0);
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SimpleTokenCounter {
|
||||
max_context: usize,
|
||||
}
|
||||
|
||||
impl SimpleTokenCounter {
|
||||
/// Create a new simple token counter
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `max_context` - Maximum context window size for the model
|
||||
pub fn new(max_context: usize) -> Self {
|
||||
Self { max_context }
|
||||
}
|
||||
|
||||
/// Create a token counter with a default 8192 token context
|
||||
pub fn default_8k() -> Self {
|
||||
Self::new(8192)
|
||||
}
|
||||
|
||||
/// Create a token counter with a 32k token context
|
||||
pub fn with_32k() -> Self {
|
||||
Self::new(32768)
|
||||
}
|
||||
|
||||
/// Create a token counter with a 128k token context
|
||||
pub fn with_128k() -> Self {
|
||||
Self::new(131072)
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenCounter for SimpleTokenCounter {
|
||||
fn count(&self, text: &str) -> usize {
|
||||
// Estimate: approximately 4 characters per token for English
|
||||
// Add 3 before dividing to round up
|
||||
(text.len() + 3) / 4
|
||||
}
|
||||
|
||||
fn count_messages(&self, messages: &[ChatMessage]) -> usize {
|
||||
let mut total = 0;
|
||||
|
||||
// Base overhead for message formatting (estimated)
|
||||
// Each message has role, delimiters, etc.
|
||||
const MESSAGE_OVERHEAD: usize = 4;
|
||||
|
||||
for msg in messages {
|
||||
// Count role
|
||||
total += MESSAGE_OVERHEAD;
|
||||
|
||||
// Count content
|
||||
if let Some(content) = &msg.content {
|
||||
total += self.count(content);
|
||||
}
|
||||
|
||||
// Count tool calls (more expensive due to JSON structure)
|
||||
if let Some(tool_calls) = &msg.tool_calls {
|
||||
for tc in tool_calls {
|
||||
// ID overhead
|
||||
total += self.count(&tc.id);
|
||||
// Function name
|
||||
total += self.count(&tc.function.name);
|
||||
// Arguments (JSON serialized, add 20% overhead for JSON structure)
|
||||
let args_str = tc.function.arguments.to_string();
|
||||
total += (self.count(&args_str) * 12) / 10;
|
||||
}
|
||||
}
|
||||
|
||||
// Count tool call id for tool result messages
|
||||
if let Some(tool_call_id) = &msg.tool_call_id {
|
||||
total += self.count(tool_call_id);
|
||||
}
|
||||
|
||||
// Count tool name for tool result messages
|
||||
if let Some(name) = &msg.name {
|
||||
total += self.count(name);
|
||||
}
|
||||
}
|
||||
|
||||
total
|
||||
}
|
||||
|
||||
fn max_context(&self) -> usize {
|
||||
self.max_context
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ClaudeTokenCounter
|
||||
// ============================================================================
|
||||
|
||||
/// Token counter optimized for Anthropic Claude models
|
||||
///
|
||||
/// Claude models have specific tokenization characteristics and overhead.
|
||||
/// This counter adjusts the estimates accordingly.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use llm_core::tokens::{TokenCounter, ClaudeTokenCounter};
|
||||
/// use llm_core::ChatMessage;
|
||||
///
|
||||
/// let counter = ClaudeTokenCounter::new();
|
||||
/// let messages = vec![
|
||||
/// ChatMessage::system("You are a helpful assistant."),
|
||||
/// ChatMessage::user("Hello!"),
|
||||
/// ];
|
||||
/// let total = counter.count_messages(&messages);
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClaudeTokenCounter {
|
||||
max_context: usize,
|
||||
}
|
||||
|
||||
impl ClaudeTokenCounter {
|
||||
/// Create a new Claude token counter with default 200k context
|
||||
///
|
||||
/// This is suitable for Claude 3.5 Sonnet, Claude 4 Sonnet, and Claude 4 Opus.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
max_context: 200_000,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a Claude counter with a custom context window
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `max_context` - Maximum context window size
|
||||
pub fn with_context(max_context: usize) -> Self {
|
||||
Self { max_context }
|
||||
}
|
||||
|
||||
/// Create a counter for Claude 3 Haiku (200k context)
|
||||
pub fn haiku() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
|
||||
/// Create a counter for Claude 3.5 Sonnet (200k context)
|
||||
pub fn sonnet() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
|
||||
/// Create a counter for Claude 4 Opus (200k context)
|
||||
pub fn opus() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ClaudeTokenCounter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenCounter for ClaudeTokenCounter {
|
||||
fn count(&self, text: &str) -> usize {
|
||||
// Claude's tokenization is similar to the 4 chars/token heuristic
|
||||
// but tends to be slightly more efficient with structured content
|
||||
(text.len() + 3) / 4
|
||||
}
|
||||
|
||||
fn count_messages(&self, messages: &[ChatMessage]) -> usize {
|
||||
let mut total = 0;
|
||||
|
||||
// Claude has specific message formatting overhead
|
||||
const MESSAGE_OVERHEAD: usize = 5;
|
||||
const SYSTEM_MESSAGE_OVERHEAD: usize = 3;
|
||||
|
||||
for msg in messages {
|
||||
// Different overhead for system vs other messages
|
||||
let overhead = if matches!(msg.role, crate::Role::System) {
|
||||
SYSTEM_MESSAGE_OVERHEAD
|
||||
} else {
|
||||
MESSAGE_OVERHEAD
|
||||
};
|
||||
|
||||
total += overhead;
|
||||
|
||||
// Count content
|
||||
if let Some(content) = &msg.content {
|
||||
total += self.count(content);
|
||||
}
|
||||
|
||||
// Count tool calls
|
||||
if let Some(tool_calls) = &msg.tool_calls {
|
||||
// Claude's tool call format has additional overhead
|
||||
const TOOL_CALL_OVERHEAD: usize = 10;
|
||||
|
||||
for tc in tool_calls {
|
||||
total += TOOL_CALL_OVERHEAD;
|
||||
total += self.count(&tc.id);
|
||||
total += self.count(&tc.function.name);
|
||||
|
||||
// Arguments with JSON structure overhead
|
||||
let args_str = tc.function.arguments.to_string();
|
||||
total += (self.count(&args_str) * 12) / 10;
|
||||
}
|
||||
}
|
||||
|
||||
// Tool result overhead
|
||||
if msg.tool_call_id.is_some() {
|
||||
const TOOL_RESULT_OVERHEAD: usize = 8;
|
||||
total += TOOL_RESULT_OVERHEAD;
|
||||
|
||||
if let Some(tool_call_id) = &msg.tool_call_id {
|
||||
total += self.count(tool_call_id);
|
||||
}
|
||||
|
||||
if let Some(name) = &msg.name {
|
||||
total += self.count(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
total
|
||||
}
|
||||
|
||||
fn max_context(&self) -> usize {
|
||||
self.max_context
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ContextWindow
|
||||
// ============================================================================
|
||||
|
||||
/// Manages context window tracking for a conversation
|
||||
///
|
||||
/// Helps monitor token usage and determine when context limits are approaching.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use llm_core::tokens::{ContextWindow, TokenCounter, SimpleTokenCounter};
|
||||
/// use llm_core::ChatMessage;
|
||||
///
|
||||
/// let counter = SimpleTokenCounter::new(8192);
|
||||
/// let mut window = ContextWindow::new(counter.max_context());
|
||||
///
|
||||
/// let messages = vec![
|
||||
/// ChatMessage::user("Hello!"),
|
||||
/// ChatMessage::assistant("Hi there!"),
|
||||
/// ];
|
||||
///
|
||||
/// let tokens = counter.count_messages(&messages);
|
||||
/// window.add_tokens(tokens);
|
||||
///
|
||||
/// println!("Used: {} tokens", window.used());
|
||||
/// println!("Remaining: {} tokens", window.remaining());
|
||||
/// println!("Usage: {:.1}%", window.usage_percent() * 100.0);
|
||||
///
|
||||
/// if window.is_near_limit(0.8) {
|
||||
/// println!("Warning: Context is 80% full!");
|
||||
/// }
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ContextWindow {
|
||||
/// Number of tokens currently used
|
||||
used: usize,
|
||||
/// Maximum number of tokens allowed
|
||||
max: usize,
|
||||
}
|
||||
|
||||
impl ContextWindow {
|
||||
/// Create a new context window tracker
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `max` - Maximum context window size in tokens
|
||||
pub fn new(max: usize) -> Self {
|
||||
Self { used: 0, max }
|
||||
}
|
||||
|
||||
/// Create a context window with initial usage
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `max` - Maximum context window size
|
||||
/// * `used` - Initial number of tokens used
|
||||
pub fn with_usage(max: usize, used: usize) -> Self {
|
||||
Self { used, max }
|
||||
}
|
||||
|
||||
/// Get the number of tokens currently used
|
||||
pub fn used(&self) -> usize {
|
||||
self.used
|
||||
}
|
||||
|
||||
/// Get the maximum number of tokens
|
||||
pub fn max(&self) -> usize {
|
||||
self.max
|
||||
}
|
||||
|
||||
/// Get the number of remaining tokens
|
||||
pub fn remaining(&self) -> usize {
|
||||
self.max.saturating_sub(self.used)
|
||||
}
|
||||
|
||||
/// Get the usage as a percentage (0.0 to 1.0)
|
||||
///
|
||||
/// Returns the fraction of the context window that is currently used.
|
||||
pub fn usage_percent(&self) -> f32 {
|
||||
if self.max == 0 {
|
||||
return 0.0;
|
||||
}
|
||||
self.used as f32 / self.max as f32
|
||||
}
|
||||
|
||||
/// Check if usage is near the limit
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `threshold` - Threshold as a fraction (0.0 to 1.0). For example,
|
||||
/// 0.8 means "is usage > 80%?"
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if the current usage exceeds the threshold percentage
|
||||
pub fn is_near_limit(&self, threshold: f32) -> bool {
|
||||
self.usage_percent() > threshold
|
||||
}
|
||||
|
||||
/// Add tokens to the usage count
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `tokens` - Number of tokens to add
|
||||
pub fn add_tokens(&mut self, tokens: usize) {
|
||||
self.used = self.used.saturating_add(tokens);
|
||||
}
|
||||
|
||||
/// Set the current usage
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `used` - Number of tokens currently used
|
||||
pub fn set_used(&mut self, used: usize) {
|
||||
self.used = used;
|
||||
}
|
||||
|
||||
/// Reset the usage counter to zero
|
||||
pub fn reset(&mut self) {
|
||||
self.used = 0;
|
||||
}
|
||||
|
||||
/// Check if there's enough room for additional tokens
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `tokens` - Number of tokens needed
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if adding these tokens would stay within the limit
|
||||
pub fn has_room_for(&self, tokens: usize) -> bool {
|
||||
self.used.saturating_add(tokens) <= self.max
|
||||
}
|
||||
|
||||
/// Get a visual progress bar representation
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `width` - Width of the progress bar in characters
|
||||
///
|
||||
/// # Returns
|
||||
/// A string with a simple text-based progress bar
|
||||
pub fn progress_bar(&self, width: usize) -> String {
|
||||
if width == 0 {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
let percent = self.usage_percent();
|
||||
let filled = ((percent * width as f32) as usize).min(width);
|
||||
let empty = width - filled;
|
||||
|
||||
format!(
|
||||
"[{}{}] {:.1}%",
|
||||
"=".repeat(filled),
|
||||
" ".repeat(empty),
|
||||
percent * 100.0
|
||||
)
|
||||
}
|
||||
}
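// Illustrative sketch: a pre-flight check that combines a `TokenCounter` with
// the window tracker above before appending new messages to a conversation.
#[allow(dead_code)]
fn example_fits_in_context(
    counter: &dyn TokenCounter,
    window: &ContextWindow,
    new_messages: &[ChatMessage],
) -> bool {
    window.has_room_for(counter.count_messages(new_messages))
}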
|
||||
|
||||
// ============================================================================
|
||||
// Tests
|
||||
// ============================================================================
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{ChatMessage, FunctionCall, ToolCall};
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_simple_counter_basic() {
|
||||
let counter = SimpleTokenCounter::new(8192);
|
||||
|
||||
// Empty string
|
||||
assert_eq!(counter.count(""), 0);
|
||||
|
||||
// Short string (~4 chars/token)
|
||||
let text = "Hello, world!"; // 13 chars -> ~4 tokens
|
||||
let count = counter.count(text);
|
||||
assert!(count >= 3 && count <= 5);
|
||||
|
||||
// Longer text
|
||||
let text = "The quick brown fox jumps over the lazy dog"; // 44 chars -> ~11 tokens
|
||||
let count = counter.count(text);
|
||||
assert!(count >= 10 && count <= 13);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_counter_messages() {
|
||||
let counter = SimpleTokenCounter::new(8192);
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::user("Hello!"),
|
||||
ChatMessage::assistant("Hi there! How can I help you today?"),
|
||||
];
|
||||
|
||||
let total = counter.count_messages(&messages);
|
||||
|
||||
// Should be more than just the text due to overhead
|
||||
let text_only = counter.count("Hello!") + counter.count("Hi there! How can I help you today?");
|
||||
assert!(total > text_only);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_counter_with_tool_calls() {
|
||||
let counter = SimpleTokenCounter::new(8192);
|
||||
|
||||
let tool_call = ToolCall {
|
||||
id: "call_123".to_string(),
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: "read_file".to_string(),
|
||||
arguments: json!({"path": "/etc/hosts"}),
|
||||
},
|
||||
};
|
||||
|
||||
let messages = vec![ChatMessage::assistant_tool_calls(vec![tool_call])];
|
||||
|
||||
let total = counter.count_messages(&messages);
|
||||
assert!(total > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_claude_counter() {
|
||||
let counter = ClaudeTokenCounter::new();
|
||||
|
||||
assert_eq!(counter.max_context(), 200_000);
|
||||
|
||||
let text = "Hello, Claude!";
|
||||
let count = counter.count(text);
|
||||
assert!(count > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_claude_counter_system_message() {
|
||||
let counter = ClaudeTokenCounter::new();
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::system("You are a helpful assistant."),
|
||||
ChatMessage::user("Hello!"),
|
||||
];
|
||||
|
||||
let total = counter.count_messages(&messages);
|
||||
assert!(total > 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_context_window() {
|
||||
let mut window = ContextWindow::new(1000);
|
||||
|
||||
assert_eq!(window.used(), 0);
|
||||
assert_eq!(window.max(), 1000);
|
||||
assert_eq!(window.remaining(), 1000);
|
||||
assert_eq!(window.usage_percent(), 0.0);
|
||||
|
||||
window.add_tokens(200);
|
||||
assert_eq!(window.used(), 200);
|
||||
assert_eq!(window.remaining(), 800);
|
||||
assert_eq!(window.usage_percent(), 0.2);
|
||||
|
||||
window.add_tokens(600);
|
||||
assert_eq!(window.used(), 800);
|
||||
assert!(window.is_near_limit(0.7));
|
||||
assert!(!window.is_near_limit(0.9));
|
||||
|
||||
assert!(window.has_room_for(200));
|
||||
assert!(!window.has_room_for(300));
|
||||
|
||||
window.reset();
|
||||
assert_eq!(window.used(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_context_window_progress_bar() {
|
||||
let mut window = ContextWindow::new(100);
|
||||
|
||||
window.add_tokens(50);
|
||||
let bar = window.progress_bar(10);
|
||||
assert!(bar.contains("====="));
|
||||
assert!(bar.contains("50.0%"));
|
||||
|
||||
window.add_tokens(40);
|
||||
let bar = window.progress_bar(10);
|
||||
assert!(bar.contains("========="));
|
||||
assert!(bar.contains("90.0%"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_context_window_saturation() {
|
||||
let mut window = ContextWindow::new(100);
|
||||
|
||||
// Adding more tokens than max should saturate, not overflow
|
||||
window.add_tokens(150);
|
||||
assert_eq!(window.used(), 150);
|
||||
assert_eq!(window.remaining(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_simple_counter_constructors() {
|
||||
let counter1 = SimpleTokenCounter::default_8k();
|
||||
assert_eq!(counter1.max_context(), 8192);
|
||||
|
||||
let counter2 = SimpleTokenCounter::with_32k();
|
||||
assert_eq!(counter2.max_context(), 32768);
|
||||
|
||||
let counter3 = SimpleTokenCounter::with_128k();
|
||||
assert_eq!(counter3.max_context(), 131072);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_claude_counter_variants() {
|
||||
let haiku = ClaudeTokenCounter::haiku();
|
||||
assert_eq!(haiku.max_context(), 200_000);
|
||||
|
||||
let sonnet = ClaudeTokenCounter::sonnet();
|
||||
assert_eq!(sonnet.max_context(), 200_000);
|
||||
|
||||
let opus = ClaudeTokenCounter::opus();
|
||||
assert_eq!(opus.max_context(), 200_000);
|
||||
|
||||
let custom = ClaudeTokenCounter::with_context(100_000);
|
||||
assert_eq!(custom.max_context(), 100_000);
|
||||
}
|
||||
}
|
||||
@@ -6,11 +6,13 @@ license.workspace = true
|
||||
rust-version.workspace = true
|
||||
|
||||
[dependencies]
|
||||
llm-core = { path = "../core" }
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
tokio = { version = "1.39", features = ["rt-multi-thread"] }
|
||||
tokio = { version = "1.39", features = ["rt-multi-thread", "macros"] }
|
||||
futures = "0.3"
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
thiserror = "1"
|
||||
bytes = "1"
|
||||
tokio-stream = "0.1.17"
|
||||
async-trait = "0.1"
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
use crate::types::{ChatMessage, ChatResponseChunk, Tool};
|
||||
use futures::{Stream, TryStreamExt};
|
||||
use futures::{Stream, StreamExt, TryStreamExt};
|
||||
use reqwest::Client;
|
||||
use serde::Serialize;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use thiserror::Error;
|
||||
use async_trait::async_trait;
|
||||
use llm_core::{
|
||||
LlmProvider, ProviderInfo, LlmError, ChatOptions, ChunkStream,
|
||||
ProviderStatus, AccountInfo, UsageStats, ModelInfo,
|
||||
};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct OllamaClient {
|
||||
http: Client,
|
||||
base_url: String, // e.g. "http://localhost:11434"
|
||||
api_key: Option<String>, // For Ollama Cloud authentication
|
||||
current_model: String, // Default model for this client
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
@@ -27,12 +33,24 @@ pub enum OllamaError {
|
||||
Protocol(String),
|
||||
}
|
||||
|
||||
// Convert OllamaError to LlmError
|
||||
impl From<OllamaError> for LlmError {
|
||||
fn from(err: OllamaError) -> Self {
|
||||
match err {
|
||||
OllamaError::Http(e) => LlmError::Http(e.to_string()),
|
||||
OllamaError::Json(e) => LlmError::Json(e.to_string()),
|
||||
OllamaError::Protocol(msg) => LlmError::Provider(msg),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OllamaClient {
|
||||
pub fn new(base_url: impl Into<String>) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
base_url: base_url.into().trim_end_matches('/').to_string(),
|
||||
api_key: None,
|
||||
current_model: "qwen3:8b".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -41,12 +59,17 @@ impl OllamaClient {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_model(mut self, model: impl Into<String>) -> Self {
|
||||
self.current_model = model.into();
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_cloud() -> Self {
|
||||
// Same API, different base
|
||||
Self::new("https://ollama.com")
|
||||
}
|
||||
|
||||
pub async fn chat_stream(
|
||||
pub async fn chat_stream_raw(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
opts: &OllamaOptions,
|
||||
@@ -99,3 +122,208 @@ impl OllamaClient {
|
||||
Ok(out)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// LlmProvider Trait Implementation
|
||||
// ============================================================================
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for OllamaClient {
|
||||
fn name(&self) -> &str {
|
||||
"ollama"
|
||||
}
|
||||
|
||||
fn model(&self) -> &str {
|
||||
&self.current_model
|
||||
}
|
||||
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
messages: &[llm_core::ChatMessage],
|
||||
options: &ChatOptions,
|
||||
tools: Option<&[llm_core::Tool]>,
|
||||
) -> Result<ChunkStream, LlmError> {
|
||||
// Convert llm_core messages to Ollama messages
|
||||
let ollama_messages: Vec<ChatMessage> = messages.iter().map(|m| m.into()).collect();
|
||||
|
||||
// Convert llm_core tools to Ollama tools if present
|
||||
let ollama_tools: Option<Vec<Tool>> = tools.map(|tools| {
|
||||
tools.iter().map(|t| Tool {
|
||||
tool_type: t.tool_type.clone(),
|
||||
function: crate::types::ToolFunction {
|
||||
name: t.function.name.clone(),
|
||||
description: t.function.description.clone(),
|
||||
parameters: crate::types::ToolParameters {
|
||||
param_type: t.function.parameters.param_type.clone(),
|
||||
properties: t.function.parameters.properties.clone(),
|
||||
required: t.function.parameters.required.clone(),
|
||||
},
|
||||
},
|
||||
}).collect()
|
||||
});
|
||||
|
||||
let opts = OllamaOptions {
|
||||
model: options.model.clone(),
|
||||
stream: true,
|
||||
};
|
||||
|
||||
// Make the request and build the body inline to avoid lifetime issues
|
||||
#[derive(Serialize)]
|
||||
struct Body<'a> {
|
||||
model: &'a str,
|
||||
messages: &'a [ChatMessage],
|
||||
stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<&'a [Tool]>,
|
||||
}
|
||||
|
||||
let url = format!("{}/api/chat", self.base_url);
|
||||
let body = Body {
|
||||
model: &opts.model,
|
||||
messages: &ollama_messages,
|
||||
stream: true,
|
||||
tools: ollama_tools.as_deref(),
|
||||
};
|
||||
|
||||
let mut req = self.http.post(url).json(&body);
|
||||
|
||||
// Add Authorization header if API key is present
|
||||
if let Some(ref key) = self.api_key {
|
||||
req = req.header("Authorization", format!("Bearer {}", key));
|
||||
}
|
||||
|
||||
let resp = req.send().await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
let bytes_stream = resp.bytes_stream();
|
||||
|
||||
// NDJSON parser: split by '\n', parse each as JSON and stream the results
|
||||
let converted_stream = bytes_stream
|
||||
.map(|result| {
|
||||
result.map_err(|e| LlmError::Http(e.to_string()))
|
||||
})
|
||||
.map_ok(|bytes| {
|
||||
// Convert the chunk to a UTF-8 string and own it
|
||||
let txt = String::from_utf8_lossy(&bytes).into_owned();
|
||||
// Parse each non-empty line into a ChatResponseChunk
|
||||
let results: Vec<Result<llm_core::StreamChunk, LlmError>> = txt
|
||||
.lines()
|
||||
.filter_map(|line| {
|
||||
let trimmed = line.trim();
|
||||
if trimmed.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
serde_json::from_str::<ChatResponseChunk>(trimmed)
|
||||
.map(|chunk| llm_core::StreamChunk::from(chunk))
|
||||
.map_err(|e| LlmError::Json(e.to_string())),
|
||||
)
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
futures::stream::iter(results)
|
||||
})
|
||||
.try_flatten();
|
||||
|
||||
Ok(Box::pin(converted_stream))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ProviderInfo Trait Implementation
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct OllamaModelList {
|
||||
models: Vec<OllamaModel>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct OllamaModel {
|
||||
name: String,
|
||||
#[serde(default)]
|
||||
modified_at: Option<String>,
|
||||
#[serde(default)]
|
||||
size: Option<u64>,
|
||||
#[serde(default)]
|
||||
digest: Option<String>,
|
||||
#[serde(default)]
|
||||
details: Option<OllamaModelDetails>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct OllamaModelDetails {
|
||||
#[serde(default)]
|
||||
format: Option<String>,
|
||||
#[serde(default)]
|
||||
family: Option<String>,
|
||||
#[serde(default)]
|
||||
parameter_size: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ProviderInfo for OllamaClient {
|
||||
async fn status(&self) -> Result<ProviderStatus, LlmError> {
|
||||
// Try to ping the Ollama server
|
||||
let url = format!("{}/api/tags", self.base_url);
|
||||
let reachable = self.http.get(&url).send().await.is_ok();
|
||||
|
||||
Ok(ProviderStatus {
|
||||
provider: "ollama".to_string(),
|
||||
authenticated: self.api_key.is_some(),
|
||||
account: None, // Ollama is local, no account info
|
||||
model: self.current_model.clone(),
|
||||
endpoint: self.base_url.clone(),
|
||||
reachable,
|
||||
message: if reachable {
|
||||
Some("Connected to Ollama".to_string())
|
||||
} else {
|
||||
Some("Cannot reach Ollama server".to_string())
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError> {
|
||||
// Ollama is a local service, no account info
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError> {
|
||||
// Ollama doesn't track usage statistics
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
|
||||
let url = format!("{}/api/tags", self.base_url);
|
||||
let mut req = self.http.get(&url);
|
||||
|
||||
// Add Authorization header if API key is present
|
||||
if let Some(ref key) = self.api_key {
|
||||
req = req.header("Authorization", format!("Bearer {}", key));
|
||||
}
|
||||
|
||||
let resp = req.send().await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
let model_list: OllamaModelList = resp.json().await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
// Convert Ollama models to ModelInfo
|
||||
let models = model_list.models.into_iter().map(|m| {
|
||||
ModelInfo {
|
||||
id: m.name.clone(),
|
||||
display_name: Some(m.name.clone()),
|
||||
description: m.details.as_ref()
|
||||
.and_then(|d| d.family.as_ref())
|
||||
.map(|f| format!("{} model", f)),
|
||||
context_window: None, // Ollama doesn't provide this in list
|
||||
max_output_tokens: None,
|
||||
supports_tools: true, // Most Ollama models support tools
|
||||
supports_vision: false, // Would need to check model capabilities
|
||||
input_price_per_mtok: None, // Local models are free
|
||||
output_price_per_mtok: None,
|
||||
}
|
||||
}).collect();
|
||||
|
||||
Ok(models)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
pub mod client;
|
||||
pub mod types;
|
||||
|
||||
pub use client::{OllamaClient, OllamaOptions};
|
||||
pub use client::{OllamaClient, OllamaOptions, OllamaError};
|
||||
pub use types::{ChatMessage, ChatResponseChunk, Tool, ToolCall, ToolFunction, ToolParameters, FunctionCall};
|
||||
|
||||
// Re-export llm-core traits and types for convenience
|
||||
pub use llm_core::{
|
||||
LlmProvider, ProviderInfo, LlmError,
|
||||
ChatOptions, StreamChunk, ToolCallDelta, Usage,
|
||||
ProviderStatus, AccountInfo, UsageStats, ModelInfo,
|
||||
Role,
|
||||
};
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use llm_core::{StreamChunk, ToolCallDelta};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
@@ -63,3 +64,67 @@ pub struct ChunkMessage {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Conversions to/from llm-core types
|
||||
// ============================================================================
|
||||
|
||||
/// Convert from llm_core::ChatMessage to Ollama's ChatMessage
|
||||
impl From<&llm_core::ChatMessage> for ChatMessage {
|
||||
fn from(msg: &llm_core::ChatMessage) -> Self {
|
||||
let role = msg.role.as_str().to_string();
|
||||
|
||||
// Convert tool_calls if present
|
||||
let tool_calls = msg.tool_calls.as_ref().map(|calls| {
|
||||
calls.iter().map(|tc| ToolCall {
|
||||
id: Some(tc.id.clone()),
|
||||
call_type: Some(tc.call_type.clone()),
|
||||
function: FunctionCall {
|
||||
name: tc.function.name.clone(),
|
||||
arguments: tc.function.arguments.clone(),
|
||||
},
|
||||
}).collect()
|
||||
});
|
||||
|
||||
ChatMessage {
|
||||
role,
|
||||
content: msg.content.clone(),
|
||||
tool_calls,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert from Ollama's ChatResponseChunk to llm_core::StreamChunk
|
||||
impl From<ChatResponseChunk> for StreamChunk {
|
||||
fn from(chunk: ChatResponseChunk) -> Self {
|
||||
let done = chunk.done.unwrap_or(false);
|
||||
let content = chunk.message.as_ref().and_then(|m| m.content.clone());
|
||||
|
||||
// Convert tool calls to deltas
|
||||
let tool_calls = chunk.message.as_ref().and_then(|m| {
|
||||
m.tool_calls.as_ref().map(|calls| {
|
||||
calls.iter().enumerate().map(|(index, tc)| {
|
||||
// Serialize arguments back to JSON string for delta
|
||||
let arguments_delta = serde_json::to_string(&tc.function.arguments).ok();
|
||||
|
||||
ToolCallDelta {
|
||||
index,
|
||||
id: tc.id.clone(),
|
||||
function_name: Some(tc.function.name.clone()),
|
||||
arguments_delta,
|
||||
}
|
||||
}).collect()
|
||||
})
|
||||
});
|
||||
|
||||
// Ollama doesn't provide per-chunk usage stats, only in final chunk
|
||||
let usage = None;
|
||||
|
||||
StreamChunk {
|
||||
content,
|
||||
tool_calls,
|
||||
done,
|
||||
usage,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
18
crates/llm/openai/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
|
||||
[package]
|
||||
name = "llm-openai"
|
||||
version = "0.1.0"
|
||||
edition.workspace = true
|
||||
license.workspace = true
|
||||
description = "OpenAI GPT API client for Owlen"
|
||||
|
||||
[dependencies]
|
||||
llm-core = { path = "../core" }
|
||||
async-trait = "0.1"
|
||||
futures = "0.3"
|
||||
reqwest = { version = "0.12", features = ["json", "stream"] }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
tokio = { version = "1", features = ["sync", "time", "io-util"] }
|
||||
tokio-stream = { version = "0.1", default-features = false, features = ["io-util"] }
|
||||
tokio-util = { version = "0.7", features = ["codec", "io"] }
|
||||
tracing = "0.1"
|
||||
285
crates/llm/openai/src/auth.rs
Normal file
@@ -0,0 +1,285 @@
|
||||
//! OpenAI OAuth Authentication
|
||||
//!
|
||||
//! Implements device code flow for authenticating with OpenAI without API keys.
|
||||
|
||||
use llm_core::{AuthMethod, DeviceAuthResult, DeviceCodeResponse, LlmError, OAuthProvider};
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// OAuth client for OpenAI device flow
|
||||
pub struct OpenAIAuth {
|
||||
http: Client,
|
||||
client_id: String,
|
||||
}
|
||||
|
||||
// OpenAI OAuth endpoints (placeholder values, mirroring the Anthropic client)
|
||||
const AUTH_BASE_URL: &str = "https://auth.openai.com";
|
||||
const DEVICE_CODE_ENDPOINT: &str = "/oauth/device/code";
|
||||
const TOKEN_ENDPOINT: &str = "/oauth/token";
|
||||
|
||||
// Default client ID for Owlen CLI
|
||||
const DEFAULT_CLIENT_ID: &str = "owlen-cli";
|
||||
|
||||
impl OpenAIAuth {
|
||||
/// Create a new OAuth client with the default CLI client ID
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
client_id: DEFAULT_CLIENT_ID.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create with a custom client ID
|
||||
pub fn with_client_id(client_id: impl Into<String>) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
client_id: client_id.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for OpenAIAuth {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct DeviceCodeRequest<'a> {
|
||||
client_id: &'a str,
|
||||
scope: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeviceCodeApiResponse {
|
||||
device_code: String,
|
||||
user_code: String,
|
||||
verification_uri: String,
|
||||
verification_uri_complete: Option<String>,
|
||||
expires_in: u64,
|
||||
interval: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct TokenRequest<'a> {
|
||||
client_id: &'a str,
|
||||
device_code: &'a str,
|
||||
grant_type: &'a str,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenApiResponse {
|
||||
access_token: String,
|
||||
#[allow(dead_code)]
|
||||
token_type: String,
|
||||
expires_in: Option<u64>,
|
||||
refresh_token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct TokenErrorResponse {
|
||||
error: String,
|
||||
error_description: Option<String>,
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl OAuthProvider for OpenAIAuth {
|
||||
async fn start_device_auth(&self) -> Result<DeviceCodeResponse, LlmError> {
|
||||
let url = format!("{}{}", AUTH_BASE_URL, DEVICE_CODE_ENDPOINT);
|
||||
|
||||
let request = DeviceCodeRequest {
|
||||
client_id: &self.client_id,
|
||||
scope: "api.read api.write",
|
||||
};
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(LlmError::Auth(format!(
|
||||
"Device code request failed ({}): {}",
|
||||
status, text
|
||||
)));
|
||||
}
|
||||
|
||||
let api_response: DeviceCodeApiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
Ok(DeviceCodeResponse {
|
||||
device_code: api_response.device_code,
|
||||
user_code: api_response.user_code,
|
||||
verification_uri: api_response.verification_uri,
|
||||
verification_uri_complete: api_response.verification_uri_complete,
|
||||
expires_in: api_response.expires_in,
|
||||
interval: api_response.interval,
|
||||
})
|
||||
}
|
||||
|
||||
async fn poll_device_auth(&self, device_code: &str) -> Result<DeviceAuthResult, LlmError> {
|
||||
let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);
|
||||
|
||||
let request = TokenRequest {
|
||||
client_id: &self.client_id,
|
||||
device_code,
|
||||
grant_type: "urn:ietf:params:oauth:grant-type:device_code",
|
||||
};
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if response.status().is_success() {
|
||||
let token_response: TokenApiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
return Ok(DeviceAuthResult::Success {
|
||||
access_token: token_response.access_token,
|
||||
refresh_token: token_response.refresh_token,
|
||||
expires_in: token_response.expires_in,
|
||||
});
|
||||
}
|
||||
|
||||
// Parse error response
|
||||
let error_response: TokenErrorResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
match error_response.error.as_str() {
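        // These are the standard RFC 8628 device-flow error codes. "slow_down" is
        // treated the same as "authorization_pending" here; a stricter client would
        // also stretch its polling interval when it sees it.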
|
||||
"authorization_pending" => Ok(DeviceAuthResult::Pending),
|
||||
"slow_down" => Ok(DeviceAuthResult::Pending),
|
||||
"access_denied" => Ok(DeviceAuthResult::Denied),
|
||||
"expired_token" => Ok(DeviceAuthResult::Expired),
|
||||
_ => Err(LlmError::Auth(format!(
|
||||
"Token request failed: {} - {}",
|
||||
error_response.error,
|
||||
error_response.error_description.unwrap_or_default()
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
async fn refresh_token(&self, refresh_token: &str) -> Result<AuthMethod, LlmError> {
|
||||
let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct RefreshRequest<'a> {
|
||||
client_id: &'a str,
|
||||
refresh_token: &'a str,
|
||||
grant_type: &'a str,
|
||||
}
|
||||
|
||||
let request = RefreshRequest {
|
||||
client_id: &self.client_id,
|
||||
refresh_token,
|
||||
grant_type: "refresh_token",
|
||||
};
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
return Err(LlmError::Auth(format!("Token refresh failed: {}", text)));
|
||||
}
|
||||
|
||||
let token_response: TokenApiResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
let expires_at = token_response.expires_in.map(|secs| {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs() + secs)
|
||||
.unwrap_or(0)
|
||||
});
|
||||
|
||||
Ok(AuthMethod::OAuth {
|
||||
access_token: token_response.access_token,
|
||||
refresh_token: token_response.refresh_token,
|
||||
expires_at,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper to perform the full device auth flow with polling
|
||||
pub async fn perform_device_auth<F>(
|
||||
auth: &OpenAIAuth,
|
||||
on_code: F,
|
||||
) -> Result<AuthMethod, LlmError>
|
||||
where
|
||||
F: FnOnce(&DeviceCodeResponse),
|
||||
{
|
||||
// Start the device flow
|
||||
let device_code = auth.start_device_auth().await?;
|
||||
|
||||
// Let caller display the code to user
|
||||
on_code(&device_code);
|
||||
|
||||
// Poll for completion
|
||||
let poll_interval = std::time::Duration::from_secs(device_code.interval);
|
||||
let deadline =
|
||||
std::time::Instant::now() + std::time::Duration::from_secs(device_code.expires_in);
|
||||
|
||||
loop {
|
||||
if std::time::Instant::now() > deadline {
|
||||
return Err(LlmError::Auth("Device code expired".to_string()));
|
||||
}
|
||||
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
|
||||
match auth.poll_device_auth(&device_code.device_code).await? {
|
||||
DeviceAuthResult::Success {
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_in,
|
||||
} => {
|
||||
let expires_at = expires_in.map(|secs| {
|
||||
std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.map(|d| d.as_secs() + secs)
|
||||
.unwrap_or(0)
|
||||
});
|
||||
|
||||
return Ok(AuthMethod::OAuth {
|
||||
access_token,
|
||||
refresh_token,
|
||||
expires_at,
|
||||
});
|
||||
}
|
||||
DeviceAuthResult::Pending => continue,
|
||||
DeviceAuthResult::Denied => {
|
||||
return Err(LlmError::Auth("Authorization denied by user".to_string()));
|
||||
}
|
||||
DeviceAuthResult::Expired => {
|
||||
return Err(LlmError::Auth("Device code expired".to_string()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
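For reference, a caller would typically drive this helper as in the sketch below; the async wrapper and the println prompt are illustrative only.

async fn login_openai() -> Result<AuthMethod, LlmError> {
    let auth = OpenAIAuth::new();
    perform_device_auth(&auth, |code| {
        // Show the user where to go and what to enter; the helper then polls
        // until the authorization succeeds, is denied, or expires.
        println!("Visit {} and enter code {}", code.verification_uri, code.user_code);
    })
    .await
}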
|
||||
561
crates/llm/openai/src/client.rs
Normal file
@@ -0,0 +1,561 @@
|
||||
//! OpenAI GPT API Client
|
||||
//!
|
||||
//! Implements the Chat Completions API with streaming support.
|
||||
|
||||
use crate::types::*;
|
||||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use llm_core::{
|
||||
AccountInfo, AuthMethod, ChatMessage, ChatOptions, ChatResponse, ChunkStream, FunctionCall,
|
||||
LlmError, LlmProvider, ModelInfo, ProviderInfo, ProviderStatus, StreamChunk, Tool, ToolCall,
|
||||
ToolCallDelta, Usage, UsageStats,
|
||||
};
|
||||
use reqwest::Client;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio_stream::wrappers::LinesStream;
|
||||
use tokio_util::io::StreamReader;
|
||||
|
||||
const API_BASE_URL: &str = "https://api.openai.com/v1";
|
||||
const CHAT_ENDPOINT: &str = "/chat/completions";
|
||||
const MODELS_ENDPOINT: &str = "/models";
|
||||
|
||||
/// OpenAI GPT API client
|
||||
pub struct OpenAIClient {
|
||||
http: Client,
|
||||
auth: AuthMethod,
|
||||
model: String,
|
||||
}
|
||||
|
||||
impl OpenAIClient {
|
||||
/// Create a new client with API key authentication
|
||||
pub fn new(api_key: impl Into<String>) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
auth: AuthMethod::api_key(api_key),
|
||||
model: "gpt-4o".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new client with OAuth token
|
||||
pub fn with_oauth(access_token: impl Into<String>) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
auth: AuthMethod::oauth(access_token),
|
||||
model: "gpt-4o".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a new client with full AuthMethod
|
||||
pub fn with_auth(auth: AuthMethod) -> Self {
|
||||
Self {
|
||||
http: Client::new(),
|
||||
auth,
|
||||
model: "gpt-4o".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the model to use
|
||||
pub fn with_model(mut self, model: impl Into<String>) -> Self {
|
||||
self.model = model.into();
|
||||
self
|
||||
}
|
||||
|
||||
/// Get current auth method (for token refresh)
|
||||
pub fn auth(&self) -> &AuthMethod {
|
||||
&self.auth
|
||||
}
|
||||
|
||||
/// Update the auth method (after refresh)
|
||||
pub fn set_auth(&mut self, auth: AuthMethod) {
|
||||
self.auth = auth;
|
||||
}
|
||||
|
||||
/// Convert messages to OpenAI format
|
||||
fn prepare_messages(messages: &[ChatMessage]) -> Vec<OpenAIMessage> {
|
||||
messages.iter().map(OpenAIMessage::from).collect()
|
||||
}
|
||||
|
||||
/// Convert tools to OpenAI format
|
||||
fn prepare_tools(tools: Option<&[Tool]>) -> Option<Vec<OpenAITool>> {
|
||||
tools.map(|t| t.iter().map(OpenAITool::from).collect())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LlmProvider for OpenAIClient {
|
||||
fn name(&self) -> &str {
|
||||
"openai"
|
||||
}
|
||||
|
||||
fn model(&self) -> &str {
|
||||
&self.model
|
||||
}
|
||||
|
||||
async fn chat_stream(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
options: &ChatOptions,
|
||||
tools: Option<&[Tool]>,
|
||||
) -> Result<ChunkStream, LlmError> {
|
||||
let url = format!("{}{}", API_BASE_URL, CHAT_ENDPOINT);
|
||||
|
||||
let model = if options.model.is_empty() {
|
||||
&self.model
|
||||
} else {
|
||||
&options.model
|
||||
};
|
||||
|
||||
let openai_messages = Self::prepare_messages(messages);
|
||||
let openai_tools = Self::prepare_tools(tools);
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model,
|
||||
messages: openai_messages,
|
||||
temperature: options.temperature,
|
||||
max_tokens: options.max_tokens,
|
||||
top_p: options.top_p,
|
||||
stop: options.stop.as_deref(),
|
||||
tools: openai_tools,
|
||||
tool_choice: None,
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let bearer = self
|
||||
.auth
|
||||
.bearer_token()
|
||||
.ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", bearer))
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
|
||||
return Err(LlmError::RateLimit {
|
||||
retry_after_secs: None,
|
||||
});
|
||||
}
|
||||
|
||||
// Try to parse as error response
|
||||
if let Ok(err_resp) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||
return Err(LlmError::Api {
|
||||
message: err_resp.error.message,
|
||||
code: err_resp.error.code,
|
||||
});
|
||||
}
|
||||
|
||||
return Err(LlmError::Api {
|
||||
message: text,
|
||||
code: Some(status.to_string()),
|
||||
});
|
||||
}
|
||||
|
||||
// Parse SSE stream
|
||||
let byte_stream = response
|
||||
.bytes_stream()
|
||||
.map(|result| result.map_err(std::io::Error::other));
|
||||
|
||||
let reader = StreamReader::new(byte_stream);
|
||||
let buf_reader = tokio::io::BufReader::new(reader);
|
||||
let lines_stream = LinesStream::new(buf_reader.lines());
|
||||
|
||||
let chunk_stream = lines_stream.filter_map(|line_result| async move {
|
||||
match line_result {
|
||||
Ok(line) => parse_sse_line(&line),
|
||||
Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Box::pin(chunk_stream))
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
options: &ChatOptions,
|
||||
tools: Option<&[Tool]>,
|
||||
) -> Result<ChatResponse, LlmError> {
|
||||
let url = format!("{}{}", API_BASE_URL, CHAT_ENDPOINT);
|
||||
|
||||
let model = if options.model.is_empty() {
|
||||
&self.model
|
||||
} else {
|
||||
&options.model
|
||||
};
|
||||
|
||||
let openai_messages = Self::prepare_messages(messages);
|
||||
let openai_tools = Self::prepare_tools(tools);
|
||||
|
||||
let request = ChatCompletionRequest {
|
||||
model,
|
||||
messages: openai_messages,
|
||||
temperature: options.temperature,
|
||||
max_tokens: options.max_tokens,
|
||||
top_p: options.top_p,
|
||||
stop: options.stop.as_deref(),
|
||||
tools: openai_tools,
|
||||
tool_choice: None,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let bearer = self
|
||||
.auth
|
||||
.bearer_token()
|
||||
.ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;
|
||||
|
||||
let response = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", bearer))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| LlmError::Http(e.to_string()))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let text = response
|
||||
.text()
|
||||
.await
|
||||
.unwrap_or_else(|_| "Unknown error".to_string());
|
||||
|
||||
if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
|
||||
return Err(LlmError::RateLimit {
|
||||
retry_after_secs: None,
|
||||
});
|
||||
}
|
||||
|
||||
if let Ok(err_resp) = serde_json::from_str::<ErrorResponse>(&text) {
|
||||
return Err(LlmError::Api {
|
||||
message: err_resp.error.message,
|
||||
code: err_resp.error.code,
|
||||
});
|
||||
}
|
||||
|
||||
return Err(LlmError::Api {
|
||||
message: text,
|
||||
code: Some(status.to_string()),
|
||||
});
|
||||
}
|
||||
|
||||
let api_response: ChatCompletionResponse = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| LlmError::Json(e.to_string()))?;
|
||||
|
||||
// Extract the first choice
|
||||
let choice = api_response
|
||||
.choices
|
||||
.first()
|
||||
.ok_or_else(|| LlmError::Api {
|
||||
message: "No choices in response".to_string(),
|
||||
code: None,
|
||||
})?;
|
||||
|
||||
let content = choice.message.content.clone();
|
||||
|
||||
let tool_calls = choice.message.tool_calls.as_ref().map(|calls| {
|
||||
calls
|
||||
.iter()
|
||||
.map(|call| {
|
||||
let arguments: serde_json::Value =
|
||||
serde_json::from_str(&call.function.arguments).unwrap_or_default();
|
||||
|
||||
ToolCall {
|
||||
id: call.id.clone(),
|
||||
call_type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: call.function.name.clone(),
|
||||
arguments,
|
||||
},
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
|
||||
let usage = api_response.usage.map(|u| Usage {
|
||||
prompt_tokens: u.prompt_tokens,
|
||||
completion_tokens: u.completion_tokens,
|
||||
total_tokens: u.total_tokens,
|
||||
});
|
||||
|
||||
Ok(ChatResponse {
|
||||
content,
|
||||
tool_calls,
|
||||
usage,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a single SSE line into a StreamChunk
|
||||
fn parse_sse_line(line: &str) -> Option<Result<StreamChunk, LlmError>> {
|
||||
let line = line.trim();
|
||||
|
||||
// Skip empty lines and comments
|
||||
if line.is_empty() || line.starts_with(':') {
|
||||
return None;
|
||||
}
|
||||
|
||||
// SSE format: "data: <json>"
|
||||
if let Some(data) = line.strip_prefix("data: ") {
|
||||
// OpenAI sends [DONE] to signal end
|
||||
if data == "[DONE]" {
|
||||
return Some(Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
done: true,
|
||||
usage: None,
|
||||
}));
|
||||
}
|
||||
|
||||
// Parse the JSON chunk
|
||||
match serde_json::from_str::<ChatCompletionChunk>(data) {
|
||||
Ok(chunk) => Some(convert_chunk_to_stream_chunk(chunk)),
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse SSE chunk: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
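A couple of concrete inputs make the contract of parse_sse_line easier to see. This is a sketch of a unit test that could live in the tests module further down.

#[test]
fn sse_parsing_handles_sentinels_and_comments() {
    // Keep-alive comments and blank lines are silently skipped.
    assert!(parse_sse_line(": keep-alive").is_none());
    assert!(parse_sse_line("").is_none());

    // The [DONE] sentinel terminates the stream with an empty, final chunk.
    let done = parse_sse_line("data: [DONE]").unwrap().unwrap();
    assert!(done.done);
    assert!(done.content.is_none());
}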
|
||||
|
||||
/// Convert OpenAI chunk to our common format
|
||||
fn convert_chunk_to_stream_chunk(chunk: ChatCompletionChunk) -> Result<StreamChunk, LlmError> {
|
||||
let choice = chunk.choices.first();
|
||||
|
||||
if let Some(choice) = choice {
|
||||
let content = choice.delta.content.clone();
|
||||
|
||||
let tool_calls = choice.delta.tool_calls.as_ref().map(|deltas| {
|
||||
deltas
|
||||
.iter()
|
||||
.map(|delta| ToolCallDelta {
|
||||
index: delta.index,
|
||||
id: delta.id.clone(),
|
||||
function_name: delta.function.as_ref().and_then(|f| f.name.clone()),
|
||||
arguments_delta: delta.function.as_ref().and_then(|f| f.arguments.clone()),
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
|
||||
let done = choice.finish_reason.is_some();
|
||||
|
||||
Ok(StreamChunk {
|
||||
content,
|
||||
tool_calls,
|
||||
done,
|
||||
usage: None,
|
||||
})
|
||||
} else {
|
||||
// No choices, treat as done
|
||||
Ok(StreamChunk {
|
||||
content: None,
|
||||
tool_calls: None,
|
||||
done: true,
|
||||
usage: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ProviderInfo Implementation
|
||||
// ============================================================================
|
||||
|
||||
/// Known GPT models with their specifications
|
||||
fn get_gpt_models() -> Vec<ModelInfo> {
|
||||
vec![
|
||||
ModelInfo {
|
||||
id: "gpt-4o".to_string(),
|
||||
display_name: Some("GPT-4o".to_string()),
|
||||
description: Some("Most advanced multimodal model with vision".to_string()),
|
||||
context_window: Some(128_000),
|
||||
max_output_tokens: Some(16_384),
|
||||
supports_tools: true,
|
||||
supports_vision: true,
|
||||
input_price_per_mtok: Some(2.50),
|
||||
output_price_per_mtok: Some(10.0),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "gpt-4o-mini".to_string(),
|
||||
display_name: Some("GPT-4o mini".to_string()),
|
||||
description: Some("Affordable and fast model for simple tasks".to_string()),
|
||||
context_window: Some(128_000),
|
||||
max_output_tokens: Some(16_384),
|
||||
supports_tools: true,
|
||||
supports_vision: true,
|
||||
input_price_per_mtok: Some(0.15),
|
||||
output_price_per_mtok: Some(0.60),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "gpt-4-turbo".to_string(),
|
||||
display_name: Some("GPT-4 Turbo".to_string()),
|
||||
description: Some("Previous generation high-performance model".to_string()),
|
||||
context_window: Some(128_000),
|
||||
max_output_tokens: Some(4_096),
|
||||
supports_tools: true,
|
||||
supports_vision: true,
|
||||
input_price_per_mtok: Some(10.0),
|
||||
output_price_per_mtok: Some(30.0),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "gpt-3.5-turbo".to_string(),
|
||||
display_name: Some("GPT-3.5 Turbo".to_string()),
|
||||
description: Some("Fast and affordable for simple tasks".to_string()),
|
||||
context_window: Some(16_385),
|
||||
max_output_tokens: Some(4_096),
|
||||
supports_tools: true,
|
||||
supports_vision: false,
|
||||
input_price_per_mtok: Some(0.50),
|
||||
output_price_per_mtok: Some(1.50),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "o1".to_string(),
|
||||
display_name: Some("OpenAI o1".to_string()),
|
||||
description: Some("Reasoning model optimized for complex problems".to_string()),
|
||||
context_window: Some(200_000),
|
||||
max_output_tokens: Some(100_000),
|
||||
supports_tools: false,
|
||||
supports_vision: true,
|
||||
input_price_per_mtok: Some(15.0),
|
||||
output_price_per_mtok: Some(60.0),
|
||||
},
|
||||
ModelInfo {
|
||||
id: "o1-mini".to_string(),
|
||||
display_name: Some("OpenAI o1-mini".to_string()),
|
||||
description: Some("Faster reasoning model for STEM".to_string()),
|
||||
context_window: Some(128_000),
|
||||
max_output_tokens: Some(65_536),
|
||||
supports_tools: false,
|
||||
supports_vision: true,
|
||||
input_price_per_mtok: Some(3.0),
|
||||
output_price_per_mtok: Some(12.0),
|
||||
},
|
||||
]
|
||||
}
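The *_price_per_mtok fields are US dollars per million tokens, so turning a Usage report into an estimated request cost is a small calculation. The helper below is illustrative, not part of the commit, and assumes the price fields are f64.

/// Rough cost estimate in USD for one request against a known model.
fn estimate_cost_usd(model: &ModelInfo, usage: &Usage) -> Option<f64> {
    let input = model.input_price_per_mtok? * usage.prompt_tokens as f64 / 1_000_000.0;
    let output = model.output_price_per_mtok? * usage.completion_tokens as f64 / 1_000_000.0;
    Some(input + output)
}

// Example: 1,200 prompt + 300 completion tokens on gpt-4o-mini
// ≈ 0.15 * 1200/1e6 + 0.60 * 300/1e6 ≈ $0.00036.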
|
||||
|
||||
#[async_trait]
|
||||
impl ProviderInfo for OpenAIClient {
|
||||
async fn status(&self) -> Result<ProviderStatus, LlmError> {
|
||||
let authenticated = self.auth.bearer_token().is_some();
|
||||
|
||||
// Try to reach the API by listing models
|
||||
let reachable = if authenticated {
|
||||
let url = format!("{}{}", API_BASE_URL, MODELS_ENDPOINT);
|
||||
let bearer = self.auth.bearer_token().unwrap();
|
||||
|
||||
match self
|
||||
.http
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", bearer))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(resp) => resp.status().is_success(),
|
||||
Err(_) => false,
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
let message = if !authenticated {
|
||||
Some("Not authenticated - set OPENAI_API_KEY or run 'owlen login openai'".to_string())
|
||||
} else if !reachable {
|
||||
Some("Cannot reach OpenAI API".to_string())
|
||||
} else {
|
||||
Some("Connected".to_string())
|
||||
};
|
||||
|
||||
Ok(ProviderStatus {
|
||||
provider: "openai".to_string(),
|
||||
authenticated,
|
||||
account: None, // OpenAI doesn't expose account info via API
|
||||
model: self.model.clone(),
|
||||
endpoint: API_BASE_URL.to_string(),
|
||||
reachable,
|
||||
message,
|
||||
})
|
||||
}
|
||||
|
||||
async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError> {
|
||||
// OpenAI doesn't have a public account info endpoint
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError> {
|
||||
// OpenAI doesn't expose usage stats via the standard API
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
|
||||
// We could fetch the model list from the API, but return the known models for now
|
||||
Ok(get_gpt_models())
|
||||
}
|
||||
|
||||
async fn model_info(&self, model_id: &str) -> Result<Option<ModelInfo>, LlmError> {
|
||||
let models = get_gpt_models();
|
||||
Ok(models.into_iter().find(|m| m.id == model_id))
|
||||
}
|
||||
}
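As a small illustrative helper (not part of this commit), the status report can be surfaced to the user like this:

async fn print_status(client: &OpenAIClient) -> Result<(), LlmError> {
    let status = client.status().await?;
    println!(
        "{}: authenticated={} reachable={} ({})",
        status.provider,
        status.authenticated,
        status.reachable,
        status.message.unwrap_or_default()
    );
    Ok(())
}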
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use llm_core::ToolParameters;
|
||||
use serde_json::json;
|
||||
|
||||
#[test]
|
||||
fn test_message_conversion() {
|
||||
let messages = vec![
|
||||
ChatMessage::system("You are helpful"),
|
||||
ChatMessage::user("Hello"),
|
||||
ChatMessage::assistant("Hi there!"),
|
||||
];
|
||||
|
||||
let openai_msgs = OpenAIClient::prepare_messages(&messages);
|
||||
|
||||
assert_eq!(openai_msgs.len(), 3);
|
||||
assert_eq!(openai_msgs[0].role, "system");
|
||||
assert_eq!(openai_msgs[1].role, "user");
|
||||
assert_eq!(openai_msgs[2].role, "assistant");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_conversion() {
|
||||
let tools = vec![Tool::function(
|
||||
"read_file",
|
||||
"Read a file's contents",
|
||||
ToolParameters::object(
|
||||
json!({
|
||||
"path": {
|
||||
"type": "string",
|
||||
"description": "File path"
|
||||
}
|
||||
}),
|
||||
vec!["path".to_string()],
|
||||
),
|
||||
)];
|
||||
|
||||
let openai_tools = OpenAIClient::prepare_tools(Some(&tools)).unwrap();
|
||||
|
||||
assert_eq!(openai_tools.len(), 1);
|
||||
assert_eq!(openai_tools[0].function.name, "read_file");
|
||||
assert_eq!(
|
||||
openai_tools[0].function.description,
|
||||
"Read a file's contents"
|
||||
);
|
||||
}
|
||||
}
|
||||
12
crates/llm/openai/src/lib.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
//! OpenAI GPT API Client
|
||||
//!
|
||||
//! Implements the LlmProvider trait for OpenAI's GPT models.
|
||||
//! Supports both API key authentication and OAuth device flow.
|
||||
|
||||
mod auth;
|
||||
mod client;
|
||||
mod types;
|
||||
|
||||
pub use auth::*;
|
||||
pub use client::*;
|
||||
pub use types::*;
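To tie the pieces together, a minimal non-streaming call through this crate could look like the sketch below; it assumes ChatOptions implements Default and that OPENAI_API_KEY is set in the environment.

use llm_core::{ChatMessage, ChatOptions, LlmError, LlmProvider};
use llm_openai::OpenAIClient;

async fn ask(prompt: &str) -> Result<Option<String>, LlmError> {
    // API key auth; OAuth works the same way via OpenAIClient::with_auth(...).
    let api_key = std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY not set");
    let client = OpenAIClient::new(api_key).with_model("gpt-4o-mini");

    let messages = vec![ChatMessage::user(prompt)];
    let response = client.chat(&messages, &ChatOptions::default(), None).await?;
    Ok(response.content)
}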
|
||||
285
crates/llm/openai/src/types.rs
Normal file
@@ -0,0 +1,285 @@
|
||||
//! OpenAI API request/response types
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
// ============================================================================
|
||||
// Request Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
pub struct ChatCompletionRequest<'a> {
|
||||
pub model: &'a str,
|
||||
pub messages: Vec<OpenAIMessage>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<u32>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<&'a [String]>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<OpenAITool>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_choice: Option<&'a str>,
|
||||
|
||||
pub stream: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenAIMessage {
|
||||
pub role: String, // "system", "user", "assistant", "tool"
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<OpenAIToolCall>>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenAIToolCall {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub call_type: String,
|
||||
pub function: OpenAIFunctionCall,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenAIFunctionCall {
|
||||
pub name: String,
|
||||
pub arguments: String, // JSON string
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenAITool {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: String,
|
||||
pub function: OpenAIFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OpenAIFunction {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: FunctionParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionParameters {
|
||||
#[serde(rename = "type")]
|
||||
pub param_type: String,
|
||||
pub properties: Value,
|
||||
pub required: Vec<String>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Response Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ChatCompletionResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<Choice>,
|
||||
pub usage: Option<UsageInfo>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct Choice {
|
||||
pub index: u32,
|
||||
pub message: OpenAIMessage,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct UsageInfo {
|
||||
pub prompt_tokens: u32,
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Streaming Response Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ChatCompletionChunk {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: String,
|
||||
pub choices: Vec<ChunkChoice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ChunkChoice {
|
||||
pub index: u32,
|
||||
pub delta: Delta,
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct Delta {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub role: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<DeltaToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct DeltaToolCall {
|
||||
pub index: usize,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none", rename = "type")]
|
||||
pub call_type: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub function: Option<DeltaFunction>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct DeltaFunction {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub arguments: Option<String>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Error Response Types
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ErrorResponse {
|
||||
pub error: ApiError,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ApiError {
|
||||
pub message: String,
|
||||
#[serde(rename = "type")]
|
||||
pub error_type: String,
|
||||
pub code: Option<String>,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Models List Response
|
||||
// ============================================================================
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ModelsResponse {
|
||||
pub object: String,
|
||||
pub data: Vec<ModelData>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ModelData {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub owned_by: String,
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Conversions
|
||||
// ============================================================================
|
||||
|
||||
impl From<&llm_core::Tool> for OpenAITool {
|
||||
fn from(tool: &llm_core::Tool) -> Self {
|
||||
Self {
|
||||
tool_type: "function".to_string(),
|
||||
function: OpenAIFunction {
|
||||
name: tool.function.name.clone(),
|
||||
description: tool.function.description.clone(),
|
||||
parameters: FunctionParameters {
|
||||
param_type: tool.function.parameters.param_type.clone(),
|
||||
properties: tool.function.parameters.properties.clone(),
|
||||
required: tool.function.parameters.required.clone(),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
}
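For orientation, this conversion yields the standard OpenAI tools payload; serialized with serde_json, the read_file tool from the tests takes roughly this shape:

// {
//   "type": "function",
//   "function": {
//     "name": "read_file",
//     "description": "Read a file's contents",
//     "parameters": {
//       "type": "object",
//       "properties": { "path": { "type": "string", "description": "File path" } },
//       "required": ["path"]
//     }
//   }
// }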
|
||||
|
||||
impl From<&llm_core::ChatMessage> for OpenAIMessage {
|
||||
fn from(msg: &llm_core::ChatMessage) -> Self {
|
||||
use llm_core::Role;
|
||||
|
||||
let role = match msg.role {
|
||||
Role::System => "system",
|
||||
Role::User => "user",
|
||||
Role::Assistant => "assistant",
|
||||
Role::Tool => "tool",
|
||||
};
|
||||
|
||||
// Handle tool result messages
|
||||
if msg.role == Role::Tool {
|
||||
return Self {
|
||||
role: "tool".to_string(),
|
||||
content: msg.content.clone(),
|
||||
tool_calls: None,
|
||||
tool_call_id: msg.tool_call_id.clone(),
|
||||
name: msg.name.clone(),
|
||||
};
|
||||
}
|
||||
|
||||
// Handle assistant messages with tool calls
|
||||
if msg.role == Role::Assistant && msg.tool_calls.is_some() {
|
||||
let tool_calls = msg.tool_calls.as_ref().map(|calls| {
|
||||
calls
|
||||
.iter()
|
||||
.map(|call| OpenAIToolCall {
|
||||
id: call.id.clone(),
|
||||
call_type: "function".to_string(),
|
||||
function: OpenAIFunctionCall {
|
||||
name: call.function.name.clone(),
|
||||
arguments: serde_json::to_string(&call.function.arguments)
|
||||
.unwrap_or_else(|_| "{}".to_string()),
|
||||
},
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
|
||||
return Self {
|
||||
role: "assistant".to_string(),
|
||||
content: msg.content.clone(),
|
||||
tool_calls,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
};
|
||||
}
|
||||
|
||||
// Simple text message
|
||||
Self {
|
||||
role: role.to_string(),
|
||||
content: msg.content.clone(),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
}