feat(v2): complete multi-LLM providers, TUI redesign, and advanced agent features
Multi-LLM Provider Support:
- Add llm-core crate with LlmProvider trait abstraction
- Implement Anthropic Claude API client with streaming
- Implement OpenAI API client with streaming
- Add token counting with SimpleTokenCounter and ClaudeTokenCounter
- Add retry logic with exponential backoff and jitter

Borderless TUI Redesign:
- Rewrite theme system with terminal capability detection (Full/Unicode256/Basic)
- Add provider tabs component with keybind switching [1]/[2]/[3]
- Implement vim-modal input (Normal/Insert/Visual/Command modes)
- Redesign chat panel with timestamps and streaming indicators
- Add multi-provider status bar with cost tracking
- Add Nerd Font icons with graceful ASCII fallbacks
- Add syntax highlighting (syntect) and markdown rendering (pulldown-cmark)

Advanced Agent Features:
- Add system prompt builder with configurable components
- Enhance subagent orchestration with parallel execution
- Add git integration module for safe command detection
- Add streaming tool results via channels
- Expand tool set: AskUserQuestion, TodoWrite, LS, MultiEdit, BashOutput, KillShell
- Add WebSearch with provider abstraction

Plugin System Enhancement:
- Add full agent definition parsing from YAML frontmatter
- Add skill system with progressive disclosure
- Wire plugin hooks into HookManager

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

crates/llm/core/Cargo.toml (new file, 18 lines)
@@ -0,0 +1,18 @@
[package]
name = "llm-core"
version = "0.1.0"
edition.workspace = true
license.workspace = true
description = "LLM provider abstraction layer for Owlen"

[dependencies]
async-trait = "0.1"
futures = "0.3"
rand = "0.8"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
thiserror = "2.0"
tokio = { version = "1.0", features = ["time"] }

[dev-dependencies]
tokio = { version = "1.0", features = ["macros", "rt"] }

crates/llm/core/examples/token_counting.rs (new file, 195 lines)
@@ -0,0 +1,195 @@
//! Token counting example
//!
//! This example demonstrates how to use the token counting utilities
//! to manage LLM context windows.
//!
//! Run with: cargo run --example token_counting -p llm-core

use llm_core::{
    ChatMessage, ClaudeTokenCounter, ContextWindow, SimpleTokenCounter, TokenCounter,
};

fn main() {
    println!("=== Token Counting Example ===\n");

    // Example 1: Basic token counting with SimpleTokenCounter
    println!("1. Basic Token Counting");
    println!("{}", "-".repeat(50));

    let simple_counter = SimpleTokenCounter::new(8192);
    let text = "The quick brown fox jumps over the lazy dog.";

    let token_count = simple_counter.count(text);
    println!("Text: \"{}\"", text);
    println!("Estimated tokens: {}", token_count);
    println!("Max context: {}\n", simple_counter.max_context());

    // Example 2: Counting tokens in chat messages
    println!("2. Counting Tokens in Chat Messages");
    println!("{}", "-".repeat(50));

    let messages = vec![
        ChatMessage::system("You are a helpful assistant that provides concise answers."),
        ChatMessage::user("What is the capital of France?"),
        ChatMessage::assistant("The capital of France is Paris."),
        ChatMessage::user("What is its population?"),
    ];

    let total_tokens = simple_counter.count_messages(&messages);
    println!("Number of messages: {}", messages.len());
    println!("Total tokens (with overhead): {}\n", total_tokens);

    // Example 3: Using ClaudeTokenCounter for Claude models
    println!("3. Claude-Specific Token Counting");
    println!("{}", "-".repeat(50));

    let claude_counter = ClaudeTokenCounter::new();
    let claude_total = claude_counter.count_messages(&messages);

    println!("Claude counter max context: {}", claude_counter.max_context());
    println!("Claude estimated tokens: {}\n", claude_total);

    // Example 4: Context window management
    println!("4. Context Window Management");
    println!("{}", "-".repeat(50));

    let mut context = ContextWindow::new(8192);
    println!("Created context window with max: {} tokens", context.max());

    // Simulate adding messages
    let conversation = vec![
        ChatMessage::user("Tell me about Rust programming."),
        ChatMessage::assistant(
            "Rust is a systems programming language focused on safety, \
             speed, and concurrency. It prevents common bugs like null pointer \
             dereferences and data races through its ownership system.",
        ),
        ChatMessage::user("What are its main features?"),
        ChatMessage::assistant(
            "Rust's main features include: 1) Memory safety without garbage collection, \
             2) Zero-cost abstractions, 3) Fearless concurrency, 4) Pattern matching, \
             5) Type inference, and 6) A powerful macro system.",
        ),
    ];

    for (i, msg) in conversation.iter().enumerate() {
        let tokens = simple_counter.count_messages(&[msg.clone()]);
        context.add_tokens(tokens);

        let role = msg.role.as_str();
        let preview = msg
            .content
            .as_ref()
            .map(|c| {
                if c.len() > 50 {
                    format!("{}...", &c[..50])
                } else {
                    c.clone()
                }
            })
            .unwrap_or_default();

        println!("Message {}: [{}] \"{}\"", i + 1, role, preview);
        println!("  Added {} tokens", tokens);
        println!("  Total used: {} / {}", context.used(), context.max());
        println!("  Usage: {:.1}%", context.usage_percent() * 100.0);
        println!("  Progress: {}\n", context.progress_bar(30));
    }

    // Example 5: Checking context limits
    println!("5. Checking Context Limits");
    println!("{}", "-".repeat(50));

    if context.is_near_limit(0.8) {
        println!("Warning: Context is over 80% full!");
    } else {
        println!("Context usage is below 80%");
    }

    let remaining = context.remaining();
    println!("Remaining tokens: {}", remaining);

    let new_message_tokens = 500;
    if context.has_room_for(new_message_tokens) {
        println!("Can fit a message of {} tokens", new_message_tokens);
    } else {
        println!(
            "Cannot fit a message of {} tokens - would need to compact or start new context",
            new_message_tokens
        );
    }

    // Example 6: Different counter variants
    println!("\n6. Using Different Counter Variants");
    println!("{}", "-".repeat(50));

    let counter_8k = SimpleTokenCounter::default_8k();
    let counter_32k = SimpleTokenCounter::with_32k();
    let counter_128k = SimpleTokenCounter::with_128k();

    println!("8k context counter: {} tokens", counter_8k.max_context());
    println!("32k context counter: {} tokens", counter_32k.max_context());
    println!("128k context counter: {} tokens", counter_128k.max_context());

    let haiku = ClaudeTokenCounter::haiku();
    let sonnet = ClaudeTokenCounter::sonnet();
    let opus = ClaudeTokenCounter::opus();

    println!("\nClaude Haiku: {} tokens", haiku.max_context());
    println!("Claude Sonnet: {} tokens", sonnet.max_context());
    println!("Claude Opus: {} tokens", opus.max_context());

    // Example 7: Managing context for a long conversation
    println!("\n7. Long Conversation Simulation");
    println!("{}", "-".repeat(50));

    let mut long_context = ContextWindow::new(4096); // Smaller context for demo
    let counter = SimpleTokenCounter::new(4096);

    let mut message_count = 0;
    let mut compaction_count = 0;

    // Simulate 20 exchanges
    for i in 0..20 {
        let user_msg = ChatMessage::user(format!(
            "This is user message number {} asking a question.",
            i + 1
        ));
        let assistant_msg = ChatMessage::assistant(format!(
            "This is assistant response number {} providing a detailed answer with multiple sentences to make it longer.",
            i + 1
        ));

        let tokens_needed = counter.count_messages(&[user_msg, assistant_msg]);

        if !long_context.has_room_for(tokens_needed) {
            println!(
                "After {} messages, context is full ({}%). Compacting...",
                message_count,
                (long_context.usage_percent() * 100.0) as u32
            );
            // In a real scenario, we would compact the conversation
            // For now, just reset
            long_context.reset();
            compaction_count += 1;
        }

        long_context.add_tokens(tokens_needed);
        message_count += 2;
    }

    println!("Total messages: {}", message_count);
    println!("Compactions needed: {}", compaction_count);
    println!("Final context usage: {:.1}%", long_context.usage_percent() * 100.0);
    println!("Final progress: {}", long_context.progress_bar(40));

    println!("\n=== Example Complete ===");
}

crates/llm/core/src/lib.rs (new file, 796 lines)
@@ -0,0 +1,796 @@
//! LLM Provider Abstraction Layer
//!
//! This crate defines the common types and traits for LLM provider integration.
//! Providers (Ollama, Anthropic Claude, OpenAI) implement the `LlmProvider` trait
//! to enable swapping providers at runtime.

use async_trait::async_trait;
use futures::Stream;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::pin::Pin;
use thiserror::Error;

// ============================================================================
// Public Modules
// ============================================================================

pub mod retry;
pub mod tokens;

// Re-export token counting types for convenience
pub use tokens::{ClaudeTokenCounter, ContextWindow, SimpleTokenCounter, TokenCounter};

// Re-export retry types for convenience
pub use retry::{is_retryable_error, RetryConfig, RetryStrategy};

// ============================================================================
// Error Types
// ============================================================================

#[derive(Error, Debug)]
pub enum LlmError {
    #[error("HTTP error: {0}")]
    Http(String),

    #[error("JSON parsing error: {0}")]
    Json(String),

    #[error("Authentication error: {0}")]
    Auth(String),

    #[error("Rate limit exceeded: retry after {retry_after_secs:?} seconds")]
    RateLimit { retry_after_secs: Option<u64> },

    #[error("API error: {message}")]
    Api { message: String, code: Option<String> },

    #[error("Provider error: {0}")]
    Provider(String),

    #[error("Stream error: {0}")]
    Stream(String),

    #[error("Request timeout: {0}")]
    Timeout(String),
}
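
// Editor's sketch (not part of this commit): one plausible way a provider
// client could map an HTTP response to `LlmError`. Keeping the numeric status
// in the `Http` message matters because `retry::is_retryable_error`
// string-matches on "500"/"502"/"503"/"504". The function name and the
// status-to-variant mapping are assumptions for illustration.
fn classify_status(status: u16, body: String) -> LlmError {
    match status {
        401 | 403 => LlmError::Auth(body),
        429 => LlmError::RateLimit { retry_after_secs: None },
        500..=599 => LlmError::Http(format!("{} {}", status, body)),
        other => LlmError::Api { message: body, code: Some(other.to_string()) },
    }
}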

// ============================================================================
// Message Types
// ============================================================================

/// Role of a message in the conversation
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Role {
    System,
    User,
    Assistant,
    Tool,
}

impl Role {
    pub fn as_str(&self) -> &'static str {
        match self {
            Role::System => "system",
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::Tool => "tool",
        }
    }
}

impl From<&str> for Role {
    fn from(s: &str) -> Self {
        match s.to_lowercase().as_str() {
            "system" => Role::System,
            "user" => Role::User,
            "assistant" => Role::Assistant,
            "tool" => Role::Tool,
            _ => Role::User, // Default fallback
        }
    }
}

/// A message in the conversation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    pub role: Role,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,

    /// Tool calls made by the assistant
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<ToolCall>>,

    /// For tool role messages: the ID of the tool call this responds to
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,

    /// For tool role messages: the name of the tool
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}

impl ChatMessage {
    /// Create a system message
    pub fn system(content: impl Into<String>) -> Self {
        Self {
            role: Role::System,
            content: Some(content.into()),
            tool_calls: None,
            tool_call_id: None,
            name: None,
        }
    }

    /// Create a user message
    pub fn user(content: impl Into<String>) -> Self {
        Self {
            role: Role::User,
            content: Some(content.into()),
            tool_calls: None,
            tool_call_id: None,
            name: None,
        }
    }

    /// Create an assistant message
    pub fn assistant(content: impl Into<String>) -> Self {
        Self {
            role: Role::Assistant,
            content: Some(content.into()),
            tool_calls: None,
            tool_call_id: None,
            name: None,
        }
    }

    /// Create an assistant message with tool calls (no text content)
    pub fn assistant_tool_calls(tool_calls: Vec<ToolCall>) -> Self {
        Self {
            role: Role::Assistant,
            content: None,
            tool_calls: Some(tool_calls),
            tool_call_id: None,
            name: None,
        }
    }

    /// Create a tool result message
    pub fn tool_result(tool_call_id: impl Into<String>, content: impl Into<String>) -> Self {
        Self {
            role: Role::Tool,
            content: Some(content.into()),
            tool_calls: None,
            tool_call_id: Some(tool_call_id.into()),
            name: None,
        }
    }
}
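
// Editor's sketch (not part of this commit): a full tool-use round trip built
// from the constructors above. The "call_1" id and the get_weather tool are
// invented for illustration; the tool result answers the call via the same id.
fn example_tool_round_trip() -> Vec<ChatMessage> {
    let call = ToolCall {
        id: "call_1".to_string(),
        call_type: "function".to_string(),
        function: FunctionCall {
            name: "get_weather".to_string(),
            arguments: serde_json::json!({ "city": "Paris" }),
        },
    };
    vec![
        ChatMessage::user("What's the weather in Paris?"),
        ChatMessage::assistant_tool_calls(vec![call]),
        ChatMessage::tool_result("call_1", r#"{"temp_c": 18, "sky": "clear"}"#),
    ]
}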

// ============================================================================
// Tool Types
// ============================================================================

/// A tool call requested by the LLM
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
    /// Unique identifier for this tool call
    pub id: String,

    /// The type of tool call (always "function" for now)
    #[serde(rename = "type", default = "default_function_type")]
    pub call_type: String,

    /// The function being called
    pub function: FunctionCall,
}

fn default_function_type() -> String {
    "function".to_string()
}

/// Details of a function call
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct FunctionCall {
    /// Name of the function to call
    pub name: String,

    /// Arguments as a JSON object
    pub arguments: Value,
}

/// Definition of a tool available to the LLM
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tool {
    #[serde(rename = "type")]
    pub tool_type: String,

    pub function: ToolFunction,
}

impl Tool {
    /// Create a new function tool
    pub fn function(
        name: impl Into<String>,
        description: impl Into<String>,
        parameters: ToolParameters,
    ) -> Self {
        Self {
            tool_type: "function".to_string(),
            function: ToolFunction {
                name: name.into(),
                description: description.into(),
                parameters,
            },
        }
    }
}

/// Function definition within a tool
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolFunction {
    pub name: String,
    pub description: String,
    pub parameters: ToolParameters,
}

/// Parameters schema for a function
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameters {
    #[serde(rename = "type")]
    pub param_type: String,

    /// JSON Schema properties object
    pub properties: Value,

    /// Required parameter names
    pub required: Vec<String>,
}

impl ToolParameters {
    /// Create an object parameter schema
    pub fn object(properties: Value, required: Vec<String>) -> Self {
        Self {
            param_type: "object".to_string(),
            properties,
            required,
        }
    }
}
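
// Editor's sketch (not part of this commit): declaring a tool with a JSON
// Schema body via `Tool::function` and `ToolParameters::object`. The
// read_file tool and its schema are invented for illustration.
fn example_read_file_tool() -> Tool {
    Tool::function(
        "read_file",
        "Read a UTF-8 text file and return its contents",
        ToolParameters::object(
            serde_json::json!({
                "path": { "type": "string", "description": "Path to the file" }
            }),
            vec!["path".to_string()],
        ),
    )
}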

// ============================================================================
// Streaming Response Types
// ============================================================================

/// A chunk of a streaming response
#[derive(Debug, Clone)]
pub struct StreamChunk {
    /// Incremental text content
    pub content: Option<String>,

    /// Tool calls (may be partial/streaming)
    pub tool_calls: Option<Vec<ToolCallDelta>>,

    /// Whether this is the final chunk
    pub done: bool,

    /// Usage statistics (typically only in final chunk)
    pub usage: Option<Usage>,
}

/// Partial tool call for streaming
#[derive(Debug, Clone)]
pub struct ToolCallDelta {
    /// Index of this tool call in the array
    pub index: usize,

    /// Tool call ID (may only be present in first delta)
    pub id: Option<String>,

    /// Function name (may only be present in first delta)
    pub function_name: Option<String>,

    /// Incremental arguments string
    pub arguments_delta: Option<String>,
}

/// Token usage statistics
#[derive(Debug, Clone, Default)]
pub struct Usage {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

// ============================================================================
// Provider Configuration
// ============================================================================

/// Options for a chat request
#[derive(Debug, Clone, Default)]
pub struct ChatOptions {
    /// Model to use
    pub model: String,

    /// Temperature (0.0 - 2.0)
    pub temperature: Option<f32>,

    /// Maximum tokens to generate
    pub max_tokens: Option<u32>,

    /// Top-p sampling
    pub top_p: Option<f32>,

    /// Stop sequences
    pub stop: Option<Vec<String>>,
}

impl ChatOptions {
    pub fn new(model: impl Into<String>) -> Self {
        Self {
            model: model.into(),
            ..Default::default()
        }
    }

    pub fn with_temperature(mut self, temp: f32) -> Self {
        self.temperature = Some(temp);
        self
    }

    pub fn with_max_tokens(mut self, max: u32) -> Self {
        self.max_tokens = Some(max);
        self
    }
}

// ============================================================================
// Provider Trait
// ============================================================================

/// A boxed stream of chunks
pub type ChunkStream = Pin<Box<dyn Stream<Item = Result<StreamChunk, LlmError>> + Send>>;

/// The main trait that all LLM providers must implement
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Get the provider name (e.g., "ollama", "anthropic", "openai")
    fn name(&self) -> &str;

    /// Get the current model name
    fn model(&self) -> &str;

    /// Send a chat request and receive a streaming response
    ///
    /// # Arguments
    /// * `messages` - The conversation history
    /// * `options` - Request options (model, temperature, etc.)
    /// * `tools` - Optional list of tools the model can use
    ///
    /// # Returns
    /// A stream of response chunks
    async fn chat_stream(
        &self,
        messages: &[ChatMessage],
        options: &ChatOptions,
        tools: Option<&[Tool]>,
    ) -> Result<ChunkStream, LlmError>;

    /// Send a chat request and receive a complete response (non-streaming)
    ///
    /// Default implementation collects the stream, but providers may override
    /// for efficiency.
    async fn chat(
        &self,
        messages: &[ChatMessage],
        options: &ChatOptions,
        tools: Option<&[Tool]>,
    ) -> Result<ChatResponse, LlmError> {
        use futures::StreamExt;

        let mut stream = self.chat_stream(messages, options, tools).await?;
        let mut content = String::new();
        let mut tool_calls: Vec<PartialToolCall> = Vec::new();
        let mut usage = None;

        while let Some(chunk) = stream.next().await {
            let chunk = chunk?;

            if let Some(text) = chunk.content {
                content.push_str(&text);
            }

            if let Some(deltas) = chunk.tool_calls {
                for delta in deltas {
                    // Grow the tool_calls vec if needed
                    while tool_calls.len() <= delta.index {
                        tool_calls.push(PartialToolCall::default());
                    }

                    let partial = &mut tool_calls[delta.index];
                    if let Some(id) = delta.id {
                        partial.id = Some(id);
                    }
                    if let Some(name) = delta.function_name {
                        partial.function_name = Some(name);
                    }
                    if let Some(args) = delta.arguments_delta {
                        partial.arguments.push_str(&args);
                    }
                }
            }

            if chunk.usage.is_some() {
                usage = chunk.usage;
            }
        }

        // Convert partial tool calls to complete tool calls
        let final_tool_calls: Vec<ToolCall> = tool_calls
            .into_iter()
            .filter_map(|p| p.try_into_tool_call())
            .collect();

        Ok(ChatResponse {
            content: if content.is_empty() {
                None
            } else {
                Some(content)
            },
            tool_calls: if final_tool_calls.is_empty() {
                None
            } else {
                Some(final_tool_calls)
            },
            usage,
        })
    }
}

/// A complete chat response (non-streaming)
#[derive(Debug, Clone)]
pub struct ChatResponse {
    pub content: Option<String>,
    pub tool_calls: Option<Vec<ToolCall>>,
    pub usage: Option<Usage>,
}

/// Helper for accumulating streaming tool calls
#[derive(Default)]
struct PartialToolCall {
    id: Option<String>,
    function_name: Option<String>,
    arguments: String,
}

impl PartialToolCall {
    fn try_into_tool_call(self) -> Option<ToolCall> {
        let id = self.id?;
        let name = self.function_name?;
        let arguments: Value = serde_json::from_str(&self.arguments).ok()?;

        Some(ToolCall {
            id,
            call_type: "function".to_string(),
            function: FunctionCall { name, arguments },
        })
    }
}
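
// Editor's sketch (not part of this commit): a minimal `LlmProvider`
// implementation - an echo provider that is handy in tests. It shows the
// shape of `chat_stream`: build the chunks, then return them as a pinned,
// boxed stream. Everything here is illustrative.
struct EchoProvider;

#[async_trait]
impl LlmProvider for EchoProvider {
    fn name(&self) -> &str {
        "echo"
    }

    fn model(&self) -> &str {
        "echo-1"
    }

    async fn chat_stream(
        &self,
        messages: &[ChatMessage],
        _options: &ChatOptions,
        _tools: Option<&[Tool]>,
    ) -> Result<ChunkStream, LlmError> {
        // Echo the most recent message content back as a single, final chunk.
        let text = messages
            .iter()
            .rev()
            .find_map(|m| m.content.clone())
            .unwrap_or_default();
        let chunk = StreamChunk {
            content: Some(text),
            tool_calls: None,
            done: true,
            usage: None,
        };
        Ok(Box::pin(futures::stream::iter(vec![Ok::<_, LlmError>(chunk)])))
    }
}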

// ============================================================================
// Authentication
// ============================================================================

/// Authentication method for LLM providers
#[derive(Debug, Clone)]
pub enum AuthMethod {
    /// No authentication (for local providers like Ollama)
    None,

    /// API key authentication
    ApiKey(String),

    /// OAuth access token (from login flow)
    OAuth {
        access_token: String,
        refresh_token: Option<String>,
        expires_at: Option<u64>,
    },
}

impl AuthMethod {
    /// Create API key auth
    pub fn api_key(key: impl Into<String>) -> Self {
        Self::ApiKey(key.into())
    }

    /// Create OAuth auth from tokens
    pub fn oauth(access_token: impl Into<String>) -> Self {
        Self::OAuth {
            access_token: access_token.into(),
            refresh_token: None,
            expires_at: None,
        }
    }

    /// Create OAuth auth with refresh token
    pub fn oauth_with_refresh(
        access_token: impl Into<String>,
        refresh_token: impl Into<String>,
        expires_at: Option<u64>,
    ) -> Self {
        Self::OAuth {
            access_token: access_token.into(),
            refresh_token: Some(refresh_token.into()),
            expires_at,
        }
    }

    /// Get the bearer token for Authorization header
    pub fn bearer_token(&self) -> Option<&str> {
        match self {
            Self::None => None,
            Self::ApiKey(key) => Some(key),
            Self::OAuth { access_token, .. } => Some(access_token),
        }
    }

    /// Check if token might need refresh
    pub fn needs_refresh(&self) -> bool {
        match self {
            Self::OAuth {
                expires_at: Some(exp),
                refresh_token: Some(_),
                ..
            } => {
                let now = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .map(|d| d.as_secs())
                    .unwrap_or(0);
                // Refresh if expiring within 5 minutes
                *exp < now + 300
            }
            _ => false,
        }
    }
}
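
// Editor's sketch (not part of this commit): turning an `AuthMethod` into a
// request header value. The Bearer format is an assumption; real providers
// differ (Anthropic, for instance, sends the key in an x-api-key header).
fn example_authorization_header(auth: &AuthMethod) -> Option<String> {
    auth.bearer_token().map(|token| format!("Bearer {}", token))
}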

/// Device code response for OAuth device flow
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceCodeResponse {
    /// Code the user enters on the verification page
    pub user_code: String,

    /// URL the user visits to authorize
    pub verification_uri: String,

    /// Full URL with code pre-filled (if supported)
    pub verification_uri_complete: Option<String>,

    /// Device code for polling (internal use)
    pub device_code: String,

    /// How often to poll (in seconds)
    pub interval: u64,

    /// When the codes expire (in seconds)
    pub expires_in: u64,
}

/// Result of polling for device authorization
#[derive(Debug, Clone)]
pub enum DeviceAuthResult {
    /// Still waiting for user to authorize
    Pending,

    /// User authorized, here are the tokens
    Success {
        access_token: String,
        refresh_token: Option<String>,
        expires_in: Option<u64>,
    },

    /// User denied authorization
    Denied,

    /// Code expired
    Expired,
}

/// Trait for providers that support OAuth device flow
#[async_trait]
pub trait OAuthProvider {
    /// Start the device authorization flow
    async fn start_device_auth(&self) -> Result<DeviceCodeResponse, LlmError>;

    /// Poll for the authorization result
    async fn poll_device_auth(&self, device_code: &str) -> Result<DeviceAuthResult, LlmError>;

    /// Refresh an access token using a refresh token
    async fn refresh_token(&self, refresh_token: &str) -> Result<AuthMethod, LlmError>;
}
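
// Editor's sketch (not part of this commit): driving the device flow against
// any `OAuthProvider`, polling at the server-requested interval. Treating
// `expires_in` as seconds-from-now is an assumption for illustration.
async fn example_device_login(provider: &impl OAuthProvider) -> Result<AuthMethod, LlmError> {
    let codes = provider.start_device_auth().await?;
    println!("Visit {} and enter code {}", codes.verification_uri, codes.user_code);

    loop {
        tokio::time::sleep(std::time::Duration::from_secs(codes.interval)).await;
        match provider.poll_device_auth(&codes.device_code).await? {
            DeviceAuthResult::Pending => continue,
            DeviceAuthResult::Success {
                access_token,
                refresh_token: Some(rt),
                expires_in,
            } => {
                let now = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .map(|d| d.as_secs())
                    .unwrap_or(0);
                return Ok(AuthMethod::oauth_with_refresh(
                    access_token,
                    rt,
                    expires_in.map(|s| now + s),
                ));
            }
            DeviceAuthResult::Success { access_token, .. } => {
                return Ok(AuthMethod::oauth(access_token))
            }
            DeviceAuthResult::Denied => {
                return Err(LlmError::Auth("user denied authorization".into()))
            }
            DeviceAuthResult::Expired => {
                return Err(LlmError::Auth("device code expired".into()))
            }
        }
    }
}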

/// Stored credentials for a provider
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StoredCredentials {
    pub provider: String,
    pub access_token: String,
    pub refresh_token: Option<String>,
    pub expires_at: Option<u64>,
}

// ============================================================================
// Provider Status & Info
// ============================================================================

/// Status information for a provider connection
#[derive(Debug, Clone)]
pub struct ProviderStatus {
    /// Provider name
    pub provider: String,

    /// Whether the connection is authenticated
    pub authenticated: bool,

    /// Current user/account info if authenticated
    pub account: Option<AccountInfo>,

    /// Current model being used
    pub model: String,

    /// API endpoint URL
    pub endpoint: String,

    /// Whether the provider is reachable
    pub reachable: bool,

    /// Any status message or error
    pub message: Option<String>,
}

/// Account/user information from the provider
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccountInfo {
    /// Account/user ID
    pub id: Option<String>,

    /// Display name or email
    pub name: Option<String>,

    /// Account email
    pub email: Option<String>,

    /// Account type (free, pro, team, enterprise)
    pub account_type: Option<String>,

    /// Organization name if applicable
    pub organization: Option<String>,
}

/// Usage statistics from the provider
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UsageStats {
    /// Total tokens used in current period
    pub tokens_used: Option<u64>,

    /// Token limit for current period (if applicable)
    pub token_limit: Option<u64>,

    /// Number of requests made
    pub requests_made: Option<u64>,

    /// Request limit (if applicable)
    pub request_limit: Option<u64>,

    /// Cost incurred (if available)
    pub cost_usd: Option<f64>,

    /// Period start timestamp
    pub period_start: Option<u64>,

    /// Period end timestamp
    pub period_end: Option<u64>,
}

/// Available model information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model ID/name
    pub id: String,

    /// Human-readable display name
    pub display_name: Option<String>,

    /// Model description
    pub description: Option<String>,

    /// Context window size (tokens)
    pub context_window: Option<u32>,

    /// Max output tokens
    pub max_output_tokens: Option<u32>,

    /// Whether the model supports tool use
    pub supports_tools: bool,

    /// Whether the model supports vision/images
    pub supports_vision: bool,

    /// Input token price per 1M tokens (USD)
    pub input_price_per_mtok: Option<f64>,

    /// Output token price per 1M tokens (USD)
    pub output_price_per_mtok: Option<f64>,
}

/// Trait for providers that support status/info queries
#[async_trait]
pub trait ProviderInfo {
    /// Get the current connection status
    async fn status(&self) -> Result<ProviderStatus, LlmError>;

    /// Get account information (if authenticated)
    async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError>;

    /// Get usage statistics (if available)
    async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError>;

    /// List available models
    async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError>;

    /// Check if a specific model is available
    async fn model_info(&self, model_id: &str) -> Result<Option<ModelInfo>, LlmError> {
        let models = self.list_models().await?;
        Ok(models.into_iter().find(|m| m.id == model_id))
    }
}

// ============================================================================
// Provider Factory
// ============================================================================

/// Supported LLM providers
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ProviderType {
    Ollama,
    Anthropic,
    OpenAI,
}

impl ProviderType {
    pub fn from_str(s: &str) -> Option<Self> {
        match s.to_lowercase().as_str() {
            "ollama" => Some(Self::Ollama),
            "anthropic" | "claude" => Some(Self::Anthropic),
            "openai" | "gpt" => Some(Self::OpenAI),
            _ => None,
        }
    }

    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Ollama => "ollama",
            Self::Anthropic => "anthropic",
            Self::OpenAI => "openai",
        }
    }

    /// Default model for this provider
    pub fn default_model(&self) -> &'static str {
        match self {
            Self::Ollama => "qwen3:8b",
            Self::Anthropic => "claude-sonnet-4-20250514",
            Self::OpenAI => "gpt-4o",
        }
    }
}

impl std::fmt::Display for ProviderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}", self.as_str())
    }
}
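
// Editor's sketch (not part of this commit): resolving a user-supplied name
// (e.g. a --provider CLI flag) to a provider and its default chat options.
// The temperature value is arbitrary.
fn example_options_for(name: &str) -> Option<ChatOptions> {
    let provider = ProviderType::from_str(name)?; // accepts aliases like "claude"
    Some(ChatOptions::new(provider.default_model()).with_temperature(0.7))
}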

crates/llm/core/src/retry.rs (new file, 386 lines)
@@ -0,0 +1,386 @@
//! Error recovery and retry logic for LLM operations
//!
//! This module provides configurable retry strategies with exponential backoff
//! for handling transient failures when communicating with LLM providers.

use crate::LlmError;
use rand::Rng;
use std::time::Duration;

/// Configuration for retry behavior
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Maximum number of retry attempts
    pub max_retries: u32,
    /// Initial delay before first retry (in milliseconds)
    pub initial_delay_ms: u64,
    /// Maximum delay between retries (in milliseconds)
    pub max_delay_ms: u64,
    /// Multiplier for exponential backoff
    pub backoff_multiplier: f32,
}

impl Default for RetryConfig {
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_delay_ms: 1000,
            max_delay_ms: 30000,
            backoff_multiplier: 2.0,
        }
    }
}

impl RetryConfig {
    /// Create a new retry configuration with custom values
    pub fn new(
        max_retries: u32,
        initial_delay_ms: u64,
        max_delay_ms: u64,
        backoff_multiplier: f32,
    ) -> Self {
        Self {
            max_retries,
            initial_delay_ms,
            max_delay_ms,
            backoff_multiplier,
        }
    }

    /// Create a configuration with no retries
    pub fn no_retry() -> Self {
        Self {
            max_retries: 0,
            initial_delay_ms: 0,
            max_delay_ms: 0,
            backoff_multiplier: 1.0,
        }
    }

    /// Create a configuration with aggressive retries for rate-limited scenarios
    pub fn aggressive() -> Self {
        Self {
            max_retries: 5,
            initial_delay_ms: 2000,
            max_delay_ms: 60000,
            backoff_multiplier: 2.5,
        }
    }
}

/// Determines whether an error is retryable
///
/// # Arguments
/// * `error` - The error to check
///
/// # Returns
/// `true` if the error is transient and the operation should be retried,
/// `false` if the error is permanent and retrying won't help
pub fn is_retryable_error(error: &LlmError) -> bool {
    match error {
        // Always retry rate limits
        LlmError::RateLimit { .. } => true,

        // Always retry timeouts
        LlmError::Timeout(_) => true,

        // Retry HTTP errors that are server-side (5xx)
        LlmError::Http(msg) => {
            // Check if the error message contains a 5xx status code
            msg.contains("500")
                || msg.contains("502")
                || msg.contains("503")
                || msg.contains("504")
                || msg.contains("Internal Server Error")
                || msg.contains("Bad Gateway")
                || msg.contains("Service Unavailable")
                || msg.contains("Gateway Timeout")
        }

        // Don't retry authentication errors - they need user intervention
        LlmError::Auth(_) => false,

        // Don't retry JSON parsing errors - the data is malformed
        LlmError::Json(_) => false,

        // Don't retry API errors - these are typically client-side issues
        LlmError::Api { .. } => false,

        // Provider errors might be transient, but we conservatively don't retry
        LlmError::Provider(_) => false,

        // Stream errors are typically not retryable
        LlmError::Stream(_) => false,
    }
}

/// Strategy for retrying failed operations with exponential backoff
#[derive(Debug, Clone)]
pub struct RetryStrategy {
    config: RetryConfig,
}

impl RetryStrategy {
    /// Create a new retry strategy with the given configuration
    pub fn new(config: RetryConfig) -> Self {
        Self { config }
    }

    /// Create a retry strategy with default configuration
    pub fn default_config() -> Self {
        Self::new(RetryConfig::default())
    }

    /// Execute an async operation with retries
    ///
    /// # Arguments
    /// * `operation` - A function that returns a Future producing a Result
    ///
    /// # Returns
    /// The result of the operation, or the last error if all retries fail
    ///
    /// # Example
    /// ```ignore
    /// let strategy = RetryStrategy::default_config();
    /// let result = strategy.execute(|| async {
    ///     // Your LLM API call here
    ///     llm_client.chat(&messages, &options, None).await
    /// }).await?;
    /// ```
    pub async fn execute<F, T, Fut>(&self, operation: F) -> Result<T, LlmError>
    where
        F: Fn() -> Fut,
        Fut: std::future::Future<Output = Result<T, LlmError>>,
    {
        let mut attempt = 0;

        loop {
            // Try the operation
            match operation().await {
                Ok(result) => return Ok(result),
                Err(err) => {
                    // Check if we should retry
                    if !is_retryable_error(&err) {
                        return Err(err);
                    }

                    attempt += 1;

                    // Check if we've exhausted retries
                    if attempt > self.config.max_retries {
                        return Err(err);
                    }

                    // Calculate delay with exponential backoff and jitter
                    let delay = self.delay_for_attempt(attempt);

                    // Log retry attempt (in a real implementation, you might use tracing)
                    eprintln!(
                        "Retry attempt {}/{} after {:?}",
                        attempt, self.config.max_retries, delay
                    );

                    // Sleep before next attempt
                    tokio::time::sleep(delay).await;
                }
            }
        }
    }

    /// Calculate the delay for a given attempt number with jitter
    ///
    /// Uses exponential backoff: delay = initial_delay * (backoff_multiplier ^ (attempt - 1))
    /// Adds random jitter of ±10% to prevent thundering herd problems
    ///
    /// # Arguments
    /// * `attempt` - The attempt number (1-indexed)
    ///
    /// # Returns
    /// The delay duration to wait before the next retry
    fn delay_for_attempt(&self, attempt: u32) -> Duration {
        // Calculate base delay with exponential backoff
        let base_delay_ms = self.config.initial_delay_ms as f64
            * self.config.backoff_multiplier.powi((attempt - 1) as i32) as f64;

        // Cap at max_delay_ms
        let capped_delay_ms = base_delay_ms.min(self.config.max_delay_ms as f64);

        // Add jitter: ±10%
        let mut rng = rand::thread_rng();
        let jitter_factor = rng.gen_range(0.9..=1.1);
        let final_delay_ms = capped_delay_ms * jitter_factor;

        Duration::from_millis(final_delay_ms as u64)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    #[test]
    fn test_default_retry_config() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.initial_delay_ms, 1000);
        assert_eq!(config.max_delay_ms, 30000);
        assert_eq!(config.backoff_multiplier, 2.0);
    }

    #[test]
    fn test_no_retry_config() {
        let config = RetryConfig::no_retry();
        assert_eq!(config.max_retries, 0);
    }

    #[test]
    fn test_is_retryable_error() {
        // Retryable errors
        assert!(is_retryable_error(&LlmError::RateLimit {
            retry_after_secs: Some(60)
        }));
        assert!(is_retryable_error(&LlmError::Timeout(
            "Request timed out".to_string()
        )));
        assert!(is_retryable_error(&LlmError::Http(
            "500 Internal Server Error".to_string()
        )));
        assert!(is_retryable_error(&LlmError::Http(
            "503 Service Unavailable".to_string()
        )));

        // Non-retryable errors
        assert!(!is_retryable_error(&LlmError::Auth(
            "Invalid API key".to_string()
        )));
        assert!(!is_retryable_error(&LlmError::Json(
            "Invalid JSON".to_string()
        )));
        assert!(!is_retryable_error(&LlmError::Api {
            message: "Invalid request".to_string(),
            code: Some("400".to_string())
        }));
        assert!(!is_retryable_error(&LlmError::Http(
            "400 Bad Request".to_string()
        )));
    }

    #[test]
    fn test_delay_calculation() {
        let config = RetryConfig::default();
        let strategy = RetryStrategy::new(config);

        // Test that delays increase exponentially
        let delay1 = strategy.delay_for_attempt(1);
        let delay2 = strategy.delay_for_attempt(2);
        let delay3 = strategy.delay_for_attempt(3);

        // Base delays should be around 1000ms, 2000ms, 4000ms (with jitter)
        assert!(delay1.as_millis() >= 900 && delay1.as_millis() <= 1100);
        assert!(delay2.as_millis() >= 1800 && delay2.as_millis() <= 2200);
        assert!(delay3.as_millis() >= 3600 && delay3.as_millis() <= 4400);
    }

    #[test]
    fn test_delay_max_cap() {
        let config = RetryConfig {
            max_retries: 10,
            initial_delay_ms: 1000,
            max_delay_ms: 5000,
            backoff_multiplier: 2.0,
        };
        let strategy = RetryStrategy::new(config);

        // Even with high attempt numbers, delay should be capped
        let delay = strategy.delay_for_attempt(10);
        assert!(delay.as_millis() <= 5500); // max + jitter
    }

    #[tokio::test]
    async fn test_retry_success_on_first_attempt() {
        let strategy = RetryStrategy::default_config();
        let call_count = Arc::new(AtomicU32::new(0));
        let count_clone = call_count.clone();

        let result = strategy
            .execute(|| {
                let count = count_clone.clone();
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Ok::<_, LlmError>(42)
                }
            })
            .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_success_after_retries() {
        let config = RetryConfig::new(3, 10, 100, 2.0); // Fast retries for testing
        let strategy = RetryStrategy::new(config);
        let call_count = Arc::new(AtomicU32::new(0));
        let count_clone = call_count.clone();

        let result = strategy
            .execute(|| {
                let count = count_clone.clone();
                async move {
                    let current = count.fetch_add(1, Ordering::SeqCst) + 1;
                    if current < 3 {
                        Err(LlmError::Timeout("Timeout".to_string()))
                    } else {
                        Ok(42)
                    }
                }
            })
            .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_exhausted() {
        let config = RetryConfig::new(2, 10, 100, 2.0); // Fast retries for testing
        let strategy = RetryStrategy::new(config);
        let call_count = Arc::new(AtomicU32::new(0));
        let count_clone = call_count.clone();

        let result = strategy
            .execute(|| {
                let count = count_clone.clone();
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>(LlmError::Timeout("Always fails".to_string()))
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 3); // Initial attempt + 2 retries
    }

    #[tokio::test]
    async fn test_non_retryable_error() {
        let strategy = RetryStrategy::default_config();
        let call_count = Arc::new(AtomicU32::new(0));
        let count_clone = call_count.clone();

        let result = strategy
            .execute(|| {
                let count = count_clone.clone();
                async move {
                    count.fetch_add(1, Ordering::SeqCst);
                    Err::<(), _>(LlmError::Auth("Invalid API key".to_string()))
                }
            })
            .await;

        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 1); // Should not retry
    }
}
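
// Editor's sketch (not part of this commit): wrapping a provider call in a
// `RetryStrategy`, mirroring the `execute` doc example above. The aggressive
// preset is an arbitrary choice; pass tools instead of `None` as needed.
async fn example_chat_with_retries(
    provider: &dyn crate::LlmProvider,
    messages: &[crate::ChatMessage],
    options: &crate::ChatOptions,
) -> Result<crate::ChatResponse, LlmError> {
    let strategy = RetryStrategy::new(RetryConfig::aggressive());
    strategy
        .execute(|| async move { provider.chat(messages, options, None).await })
        .await
}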

crates/llm/core/src/tokens.rs (new file, 607 lines)
@@ -0,0 +1,607 @@
//! Token counting utilities for LLM context management
|
||||
//!
|
||||
//! This module provides token counting abstractions and implementations for
|
||||
//! managing LLM context windows. Token counters estimate token usage without
|
||||
//! requiring external tokenization libraries, using heuristic-based approaches.
|
||||
|
||||
use crate::ChatMessage;
|
||||
|
||||
// ============================================================================
|
||||
// TokenCounter Trait
|
||||
// ============================================================================
|
||||
|
||||
/// Trait for counting tokens in text and chat messages
|
||||
///
|
||||
/// Implementations provide model-specific token counting logic to help
|
||||
/// manage context windows and estimate API costs.
|
||||
pub trait TokenCounter: Send + Sync {
|
||||
/// Count tokens in a string
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `text` - The text to count tokens for
|
||||
///
|
||||
/// # Returns
|
||||
/// Estimated number of tokens
|
||||
fn count(&self, text: &str) -> usize;
|
||||
|
||||
/// Count tokens in chat messages
|
||||
///
|
||||
/// This accounts for both the message content and the overhead
|
||||
/// from the chat message structure (roles, delimiters, etc.).
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `messages` - The messages to count tokens for
|
||||
///
|
||||
/// # Returns
|
||||
/// Estimated total tokens including message structure overhead
|
||||
fn count_messages(&self, messages: &[ChatMessage]) -> usize;
|
||||
|
||||
/// Get the model's max context window size
|
||||
///
|
||||
/// # Returns
|
||||
/// Maximum number of tokens the model can handle
|
||||
fn max_context(&self) -> usize;
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// SimpleTokenCounter
|
||||
// ============================================================================
|
||||
|
||||
/// A basic token counter using simple heuristics
|
||||
///
|
||||
/// This counter uses the rule of thumb that English text averages about
|
||||
/// 4 characters per token. It adds overhead for message structure.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use llm_core::tokens::{TokenCounter, SimpleTokenCounter};
|
||||
/// use llm_core::ChatMessage;
|
||||
///
|
||||
/// let counter = SimpleTokenCounter::new(8192);
|
||||
/// let text = "Hello, world!";
|
||||
/// let tokens = counter.count(text);
|
||||
/// assert!(tokens > 0);
|
||||
///
|
||||
/// let messages = vec![
|
||||
/// ChatMessage::user("What is the weather?"),
|
||||
/// ChatMessage::assistant("I don't have access to weather data."),
|
||||
/// ];
|
||||
/// let total = counter.count_messages(&messages);
|
||||
/// assert!(total > 0);
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SimpleTokenCounter {
|
||||
max_context: usize,
|
||||
}
|
||||
|
||||
impl SimpleTokenCounter {
|
||||
/// Create a new simple token counter
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `max_context` - Maximum context window size for the model
|
||||
pub fn new(max_context: usize) -> Self {
|
||||
Self { max_context }
|
||||
}
|
||||
|
||||
/// Create a token counter with a default 8192 token context
|
||||
pub fn default_8k() -> Self {
|
||||
Self::new(8192)
|
||||
}
|
||||
|
||||
/// Create a token counter with a 32k token context
|
||||
pub fn with_32k() -> Self {
|
||||
Self::new(32768)
|
||||
}
|
||||
|
||||
/// Create a token counter with a 128k token context
|
||||
pub fn with_128k() -> Self {
|
||||
Self::new(131072)
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenCounter for SimpleTokenCounter {
|
||||
fn count(&self, text: &str) -> usize {
|
||||
// Estimate: approximately 4 characters per token for English
|
||||
// Add 3 before dividing to round up
|
||||
(text.len() + 3) / 4
|
||||
}
|
||||
|
||||
fn count_messages(&self, messages: &[ChatMessage]) -> usize {
|
||||
let mut total = 0;
|
||||
|
||||
// Base overhead for message formatting (estimated)
|
||||
// Each message has role, delimiters, etc.
|
||||
const MESSAGE_OVERHEAD: usize = 4;
|
||||
|
||||
for msg in messages {
|
||||
// Count role
|
||||
total += MESSAGE_OVERHEAD;
|
||||
|
||||
// Count content
|
||||
if let Some(content) = &msg.content {
|
||||
total += self.count(content);
|
||||
}
|
||||
|
||||
// Count tool calls (more expensive due to JSON structure)
|
||||
if let Some(tool_calls) = &msg.tool_calls {
|
||||
for tc in tool_calls {
|
||||
// ID overhead
|
||||
total += self.count(&tc.id);
|
||||
// Function name
|
||||
total += self.count(&tc.function.name);
|
||||
// Arguments (JSON serialized, add 20% overhead for JSON structure)
|
||||
let args_str = tc.function.arguments.to_string();
|
||||
total += (self.count(&args_str) * 12) / 10;
|
||||
}
|
||||
}
|
||||
|
||||
// Count tool call id for tool result messages
|
||||
if let Some(tool_call_id) = &msg.tool_call_id {
|
||||
total += self.count(tool_call_id);
|
||||
}
|
||||
|
||||
// Count tool name for tool result messages
|
||||
if let Some(name) = &msg.name {
|
||||
total += self.count(name);
|
||||
}
|
||||
}
|
||||
|
||||
total
|
||||
}
|
||||
|
||||
fn max_context(&self) -> usize {
|
||||
self.max_context
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// ClaudeTokenCounter
|
||||
// ============================================================================
|
||||
|
||||
/// Token counter optimized for Anthropic Claude models
|
||||
///
|
||||
/// Claude models have specific tokenization characteristics and overhead.
|
||||
/// This counter adjusts the estimates accordingly.
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use llm_core::tokens::{TokenCounter, ClaudeTokenCounter};
|
||||
/// use llm_core::ChatMessage;
|
||||
///
|
||||
/// let counter = ClaudeTokenCounter::new();
|
||||
/// let messages = vec![
|
||||
/// ChatMessage::system("You are a helpful assistant."),
|
||||
/// ChatMessage::user("Hello!"),
|
||||
/// ];
|
||||
/// let total = counter.count_messages(&messages);
|
||||
/// ```
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ClaudeTokenCounter {
|
||||
max_context: usize,
|
||||
}
|
||||
|
||||
impl ClaudeTokenCounter {
|
||||
/// Create a new Claude token counter with default 200k context
|
||||
///
|
||||
/// This is suitable for Claude 3.5 Sonnet, Claude 4 Sonnet, and Claude 4 Opus.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
max_context: 200_000,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a Claude counter with a custom context window
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `max_context` - Maximum context window size
|
||||
pub fn with_context(max_context: usize) -> Self {
|
||||
Self { max_context }
|
||||
}
|
||||
|
||||
/// Create a counter for Claude 3 Haiku (200k context)
|
||||
pub fn haiku() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
|
||||
/// Create a counter for Claude 3.5 Sonnet (200k context)
|
||||
pub fn sonnet() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
|
||||
/// Create a counter for Claude 4 Opus (200k context)
|
||||
pub fn opus() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ClaudeTokenCounter {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl TokenCounter for ClaudeTokenCounter {
|
||||
fn count(&self, text: &str) -> usize {
|
||||
// Claude's tokenization is similar to the 4 chars/token heuristic
|
||||
// but tends to be slightly more efficient with structured content
|
||||
(text.len() + 3) / 4
|
||||
}
|
||||
|
||||
fn count_messages(&self, messages: &[ChatMessage]) -> usize {
|
||||
let mut total = 0;
|
||||
|
||||
// Claude has specific message formatting overhead
|
||||
const MESSAGE_OVERHEAD: usize = 5;
|
||||
const SYSTEM_MESSAGE_OVERHEAD: usize = 3;
|
||||
|
||||
for msg in messages {
|
||||
// Different overhead for system vs other messages
|
||||
let overhead = if matches!(msg.role, crate::Role::System) {
|
||||
SYSTEM_MESSAGE_OVERHEAD
|
||||
} else {
|
||||
MESSAGE_OVERHEAD
|
||||
};
|
||||
|
||||
total += overhead;
|
||||
|
||||
// Count content
|
||||
if let Some(content) = &msg.content {
|
||||
total += self.count(content);
|
||||
}
|
||||
|
||||
// Count tool calls
|
||||
if let Some(tool_calls) = &msg.tool_calls {
|
||||
// Claude's tool call format has additional overhead
|
||||
const TOOL_CALL_OVERHEAD: usize = 10;
|
||||
|
||||
for tc in tool_calls {
|
||||
total += TOOL_CALL_OVERHEAD;
|
||||
total += self.count(&tc.id);
|
||||
total += self.count(&tc.function.name);
|
||||
|
||||
// Arguments with JSON structure overhead
|
||||
let args_str = tc.function.arguments.to_string();
|
||||
total += (self.count(&args_str) * 12) / 10;
|
||||
}
|
||||
}
|
||||
|
||||
// Tool result overhead
|
||||
if msg.tool_call_id.is_some() {
|
||||
const TOOL_RESULT_OVERHEAD: usize = 8;
|
||||
total += TOOL_RESULT_OVERHEAD;
|
||||
|
||||
if let Some(tool_call_id) = &msg.tool_call_id {
|
||||
total += self.count(tool_call_id);
|
||||
}
|
||||
|
||||
if let Some(name) = &msg.name {
|
||||
total += self.count(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
total
|
||||
}
|
||||
|
||||
fn max_context(&self) -> usize {
|
||||
self.max_context
|
||||
}
|
||||
}

// ============================================================================
// ContextWindow
// ============================================================================

/// Manages context window tracking for a conversation
///
/// Helps monitor token usage and determine when context limits are approaching.
///
/// # Example
/// ```
/// use llm_core::tokens::{ContextWindow, TokenCounter, SimpleTokenCounter};
/// use llm_core::ChatMessage;
///
/// let counter = SimpleTokenCounter::new(8192);
/// let mut window = ContextWindow::new(counter.max_context());
///
/// let messages = vec![
///     ChatMessage::user("Hello!"),
///     ChatMessage::assistant("Hi there!"),
/// ];
///
/// let tokens = counter.count_messages(&messages);
/// window.add_tokens(tokens);
///
/// println!("Used: {} tokens", window.used());
/// println!("Remaining: {} tokens", window.remaining());
/// println!("Usage: {:.1}%", window.usage_percent() * 100.0);
///
/// if window.is_near_limit(0.8) {
///     println!("Warning: Context is 80% full!");
/// }
/// ```
#[derive(Debug, Clone)]
pub struct ContextWindow {
    /// Number of tokens currently used
    used: usize,
    /// Maximum number of tokens allowed
    max: usize,
}

impl ContextWindow {
    /// Create a new context window tracker
    ///
    /// # Arguments
    /// * `max` - Maximum context window size in tokens
    pub fn new(max: usize) -> Self {
        Self { used: 0, max }
    }

    /// Create a context window with initial usage
    ///
    /// # Arguments
    /// * `max` - Maximum context window size
    /// * `used` - Initial number of tokens used
    pub fn with_usage(max: usize, used: usize) -> Self {
        Self { used, max }
    }

    /// Get the number of tokens currently used
    pub fn used(&self) -> usize {
        self.used
    }

    /// Get the maximum number of tokens
    pub fn max(&self) -> usize {
        self.max
    }

    /// Get the number of remaining tokens
    pub fn remaining(&self) -> usize {
        self.max.saturating_sub(self.used)
    }

    /// Get the usage as a fraction (0.0 to 1.0)
    ///
    /// Returns the fraction of the context window that is currently used.
    pub fn usage_percent(&self) -> f32 {
        if self.max == 0 {
            return 0.0;
        }
        self.used as f32 / self.max as f32
    }

    /// Check if usage is near the limit
    ///
    /// # Arguments
    /// * `threshold` - Threshold as a fraction (0.0 to 1.0). For example,
    ///   0.8 means "is usage above 80%?"
    ///
    /// # Returns
    /// `true` if the current usage exceeds the threshold fraction
    pub fn is_near_limit(&self, threshold: f32) -> bool {
        self.usage_percent() > threshold
    }

    /// Add tokens to the usage count
    ///
    /// # Arguments
    /// * `tokens` - Number of tokens to add
    pub fn add_tokens(&mut self, tokens: usize) {
        self.used = self.used.saturating_add(tokens);
    }

    /// Set the current usage
    ///
    /// # Arguments
    /// * `used` - Number of tokens currently used
    pub fn set_used(&mut self, used: usize) {
        self.used = used;
    }

    /// Reset the usage counter to zero
    pub fn reset(&mut self) {
        self.used = 0;
    }

    /// Check if there's enough room for additional tokens
    ///
    /// # Arguments
    /// * `tokens` - Number of tokens needed
    ///
    /// # Returns
    /// `true` if adding these tokens would stay within the limit
    pub fn has_room_for(&self, tokens: usize) -> bool {
        self.used.saturating_add(tokens) <= self.max
    }

    /// Get a visual progress bar representation
    ///
    /// # Arguments
    /// * `width` - Width of the progress bar in characters
    ///
    /// # Returns
    /// A string with a simple text-based progress bar
    pub fn progress_bar(&self, width: usize) -> String {
        if width == 0 {
            return String::new();
        }

        let percent = self.usage_percent();
        let filled = ((percent * width as f32) as usize).min(width);
        let empty = width - filled;

        format!(
            "[{}{}] {:.1}%",
            "=".repeat(filled),
            " ".repeat(empty),
            percent * 100.0
        )
    }
}
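
// Reviewer sketch (not part of the original commit): a hypothetical helper
// showing how `ContextWindow`, `TokenCounter`, and `ChatMessage` compose to
// trim a conversation that no longer fits. `trim_to_fit` is illustrative,
// not a crate API.
//
//   fn trim_to_fit(
//       messages: &mut Vec<ChatMessage>,
//       counter: &dyn TokenCounter,
//       window: &ContextWindow,
//   ) {
//       // Drop the oldest message until the rest fits in the remaining room.
//       while messages.len() > 1
//           && !window.has_room_for(counter.count_messages(messages))
//       {
//           messages.remove(0);
//       }
//   }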

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{ChatMessage, FunctionCall, ToolCall};
    use serde_json::json;

    #[test]
    fn test_simple_counter_basic() {
        let counter = SimpleTokenCounter::new(8192);

        // Empty string
        assert_eq!(counter.count(""), 0);

        // Short string (~4 chars/token)
        let text = "Hello, world!"; // 13 chars -> ~4 tokens
        let count = counter.count(text);
        assert!(count >= 3 && count <= 5);

        // Longer text
        let text = "The quick brown fox jumps over the lazy dog"; // 44 chars -> ~11 tokens
        let count = counter.count(text);
        assert!(count >= 10 && count <= 13);
    }

    #[test]
    fn test_simple_counter_messages() {
        let counter = SimpleTokenCounter::new(8192);

        let messages = vec![
            ChatMessage::user("Hello!"),
            ChatMessage::assistant("Hi there! How can I help you today?"),
        ];

        let total = counter.count_messages(&messages);

        // Should be more than just the text due to per-message overhead
        let text_only =
            counter.count("Hello!") + counter.count("Hi there! How can I help you today?");
        assert!(total > text_only);
    }

    #[test]
    fn test_simple_counter_with_tool_calls() {
        let counter = SimpleTokenCounter::new(8192);

        let tool_call = ToolCall {
            id: "call_123".to_string(),
            call_type: "function".to_string(),
            function: FunctionCall {
                name: "read_file".to_string(),
                arguments: json!({"path": "/etc/hosts"}),
            },
        };

        let messages = vec![ChatMessage::assistant_tool_calls(vec![tool_call])];

        let total = counter.count_messages(&messages);
        assert!(total > 0);
    }

    #[test]
    fn test_claude_counter() {
        let counter = ClaudeTokenCounter::new();

        assert_eq!(counter.max_context(), 200_000);

        let text = "Hello, Claude!";
        let count = counter.count(text);
        assert!(count > 0);
    }

    #[test]
    fn test_claude_counter_system_message() {
        let counter = ClaudeTokenCounter::new();

        let messages = vec![
            ChatMessage::system("You are a helpful assistant."),
            ChatMessage::user("Hello!"),
        ];

        let total = counter.count_messages(&messages);
        assert!(total > 0);
    }

    #[test]
    fn test_context_window() {
        let mut window = ContextWindow::new(1000);

        assert_eq!(window.used(), 0);
        assert_eq!(window.max(), 1000);
        assert_eq!(window.remaining(), 1000);
        assert_eq!(window.usage_percent(), 0.0);

        window.add_tokens(200);
        assert_eq!(window.used(), 200);
        assert_eq!(window.remaining(), 800);
        assert_eq!(window.usage_percent(), 0.2);

        window.add_tokens(600);
        assert_eq!(window.used(), 800);
        assert!(window.is_near_limit(0.7));
        assert!(!window.is_near_limit(0.9));

        assert!(window.has_room_for(200));
        assert!(!window.has_room_for(300));

        window.reset();
        assert_eq!(window.used(), 0);
    }

    #[test]
    fn test_context_window_progress_bar() {
        let mut window = ContextWindow::new(100);

        window.add_tokens(50);
        let bar = window.progress_bar(10);
        assert!(bar.contains("====="));
        assert!(bar.contains("50.0%"));

        window.add_tokens(40);
        let bar = window.progress_bar(10);
        assert!(bar.contains("========="));
        assert!(bar.contains("90.0%"));
    }

    #[test]
    fn test_context_window_saturation() {
        let mut window = ContextWindow::new(100);

        // Adding more tokens than max should leave remaining() saturated
        // at zero, not underflow
        window.add_tokens(150);
        assert_eq!(window.used(), 150);
        assert_eq!(window.remaining(), 0);
    }

    #[test]
    fn test_simple_counter_constructors() {
        let counter1 = SimpleTokenCounter::default_8k();
        assert_eq!(counter1.max_context(), 8192);

        let counter2 = SimpleTokenCounter::with_32k();
        assert_eq!(counter2.max_context(), 32768);

        let counter3 = SimpleTokenCounter::with_128k();
        assert_eq!(counter3.max_context(), 131072);
    }

    #[test]
    fn test_claude_counter_variants() {
        let haiku = ClaudeTokenCounter::haiku();
        assert_eq!(haiku.max_context(), 200_000);

        let sonnet = ClaudeTokenCounter::sonnet();
        assert_eq!(sonnet.max_context(), 200_000);

        let opus = ClaudeTokenCounter::opus();
        assert_eq!(opus.max_context(), 200_000);

        let custom = ClaudeTokenCounter::with_context(100_000);
        assert_eq!(custom.max_context(), 100_000);
    }
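
    // Added in review as an illustrative sketch: pins the documented
    // (len + 3) / 4 heuristic for ClaudeTokenCounter to a concrete string.
    #[test]
    fn test_claude_counter_heuristic_sketch() {
        let counter = ClaudeTokenCounter::new();
        // "Hello, Claude!" is 14 chars -> (14 + 3) / 4 = 4 tokens
        assert_eq!(counter.count("Hello, Claude!"), 4);
    }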
}