feat(v2): complete multi-LLM providers, TUI redesign, and advanced agent features
Multi-LLM Provider Support:
- Add llm-core crate with LlmProvider trait abstraction
- Implement Anthropic Claude API client with streaming
- Implement OpenAI API client with streaming
- Add token counting with SimpleTokenCounter and ClaudeTokenCounter
- Add retry logic with exponential backoff and jitter

Borderless TUI Redesign:
- Rewrite theme system with terminal capability detection (Full/Unicode256/Basic)
- Add provider tabs component with keybind switching [1]/[2]/[3]
- Implement vim-modal input (Normal/Insert/Visual/Command modes)
- Redesign chat panel with timestamps and streaming indicators
- Add multi-provider status bar with cost tracking
- Add Nerd Font icons with graceful ASCII fallbacks
- Add syntax highlighting (syntect) and markdown rendering (pulldown-cmark)

Advanced Agent Features:
- Add system prompt builder with configurable components
- Enhance subagent orchestration with parallel execution
- Add git integration module for safe command detection
- Add streaming tool results via channels
- Expand tool set: AskUserQuestion, TodoWrite, LS, MultiEdit, BashOutput, KillShell
- Add WebSearch with provider abstraction

Plugin System Enhancement:
- Add full agent definition parsing from YAML frontmatter
- Add skill system with progressive disclosure
- Wire plugin hooks into HookManager

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
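The retry logic named above lives in llm-core and is not part of this diff. As a rough illustration of the described policy, a backoff-with-jitter delay can be sketched as below; all names and constants here are hypothetical, not the crate's actual API.

```rust
use std::time::Duration;

/// Hypothetical sketch of the "exponential backoff and jitter" policy named
/// in the commit message; llm-core's real implementation is not in this diff.
fn backoff_delay(attempt: u32) -> Duration {
    const BASE_MS: u64 = 250;
    const MAX_MS: u64 = 30_000;

    // Exponential growth capped at MAX_MS: 250ms, 500ms, 1s, 2s, ...
    let exp_ms = BASE_MS.saturating_mul(1u64 << attempt.min(16)).min(MAX_MS);

    // "Equal jitter": keep half the delay fixed, randomize the other half.
    // A clock-based jitter source keeps this sketch dependency-free; a real
    // implementation would use a proper RNG such as the `rand` crate.
    let nanos = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map(|d| d.subsec_nanos() as u64)
        .unwrap_or(0);
    let jitter_ms = nanos % (exp_ms / 2 + 1);

    Duration::from_millis(exp_ms / 2 + jitter_ms)
}
```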
18
crates/llm/openai/Cargo.toml
Normal file
@@ -0,0 +1,18 @@
[package]
name = "llm-openai"
version = "0.1.0"
edition.workspace = true
license.workspace = true
description = "OpenAI GPT API client for Owlen"

[dependencies]
llm-core = { path = "../core" }
async-trait = "0.1"
futures = "0.3"
reqwest = { version = "0.12", features = ["json", "stream"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
tokio = { version = "1", features = ["sync", "time", "io-util"] }
tokio-stream = { version = "0.1", default-features = false, features = ["io-util"] }
tokio-util = { version = "0.7", features = ["codec", "io"] }
tracing = "0.1"
285
crates/llm/openai/src/auth.rs
Normal file
@@ -0,0 +1,285 @@
//! OpenAI OAuth Authentication
//!
//! Implements device code flow for authenticating with OpenAI without API keys.

use llm_core::{AuthMethod, DeviceAuthResult, DeviceCodeResponse, LlmError, OAuthProvider};
use reqwest::Client;
use serde::{Deserialize, Serialize};

/// OAuth client for OpenAI device flow
pub struct OpenAIAuth {
    http: Client,
    client_id: String,
}

// OpenAI OAuth endpoints
const AUTH_BASE_URL: &str = "https://auth.openai.com";
const DEVICE_CODE_ENDPOINT: &str = "/oauth/device/code";
const TOKEN_ENDPOINT: &str = "/oauth/token";

// Default client ID for Owlen CLI
const DEFAULT_CLIENT_ID: &str = "owlen-cli";

impl OpenAIAuth {
    /// Create a new OAuth client with the default CLI client ID
    pub fn new() -> Self {
        Self {
            http: Client::new(),
            client_id: DEFAULT_CLIENT_ID.to_string(),
        }
    }

    /// Create with a custom client ID
    pub fn with_client_id(client_id: impl Into<String>) -> Self {
        Self {
            http: Client::new(),
            client_id: client_id.into(),
        }
    }
}

impl Default for OpenAIAuth {
    fn default() -> Self {
        Self::new()
    }
}

#[derive(Debug, Serialize)]
struct DeviceCodeRequest<'a> {
    client_id: &'a str,
    scope: &'a str,
}

#[derive(Debug, Deserialize)]
struct DeviceCodeApiResponse {
    device_code: String,
    user_code: String,
    verification_uri: String,
    verification_uri_complete: Option<String>,
    expires_in: u64,
    interval: u64,
}

#[derive(Debug, Serialize)]
struct TokenRequest<'a> {
    client_id: &'a str,
    device_code: &'a str,
    grant_type: &'a str,
}

#[derive(Debug, Deserialize)]
struct TokenApiResponse {
    access_token: String,
    #[allow(dead_code)]
    token_type: String,
    expires_in: Option<u64>,
    refresh_token: Option<String>,
}

#[derive(Debug, Deserialize)]
struct TokenErrorResponse {
    error: String,
    error_description: Option<String>,
}

#[async_trait::async_trait]
impl OAuthProvider for OpenAIAuth {
    async fn start_device_auth(&self) -> Result<DeviceCodeResponse, LlmError> {
        let url = format!("{}{}", AUTH_BASE_URL, DEVICE_CODE_ENDPOINT);

        let request = DeviceCodeRequest {
            client_id: &self.client_id,
            scope: "api.read api.write",
        };

        let response = self
            .http
            .post(&url)
            .json(&request)
            .send()
            .await
            .map_err(|e| LlmError::Http(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(LlmError::Auth(format!(
                "Device code request failed ({}): {}",
                status, text
            )));
        }

        let api_response: DeviceCodeApiResponse = response
            .json()
            .await
            .map_err(|e| LlmError::Json(e.to_string()))?;

        Ok(DeviceCodeResponse {
            device_code: api_response.device_code,
            user_code: api_response.user_code,
            verification_uri: api_response.verification_uri,
            verification_uri_complete: api_response.verification_uri_complete,
            expires_in: api_response.expires_in,
            interval: api_response.interval,
        })
    }

    async fn poll_device_auth(&self, device_code: &str) -> Result<DeviceAuthResult, LlmError> {
        let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);

        let request = TokenRequest {
            client_id: &self.client_id,
            device_code,
            grant_type: "urn:ietf:params:oauth:grant-type:device_code",
        };

        let response = self
            .http
            .post(&url)
            .json(&request)
            .send()
            .await
            .map_err(|e| LlmError::Http(e.to_string()))?;

        if response.status().is_success() {
            let token_response: TokenApiResponse = response
                .json()
                .await
                .map_err(|e| LlmError::Json(e.to_string()))?;

            return Ok(DeviceAuthResult::Success {
                access_token: token_response.access_token,
                refresh_token: token_response.refresh_token,
                expires_in: token_response.expires_in,
            });
        }

        // Parse error response
        let error_response: TokenErrorResponse = response
            .json()
            .await
            .map_err(|e| LlmError::Json(e.to_string()))?;

        match error_response.error.as_str() {
            "authorization_pending" => Ok(DeviceAuthResult::Pending),
            "slow_down" => Ok(DeviceAuthResult::Pending),
            "access_denied" => Ok(DeviceAuthResult::Denied),
            "expired_token" => Ok(DeviceAuthResult::Expired),
            _ => Err(LlmError::Auth(format!(
                "Token request failed: {} - {}",
                error_response.error,
                error_response.error_description.unwrap_or_default()
            ))),
        }
    }

    async fn refresh_token(&self, refresh_token: &str) -> Result<AuthMethod, LlmError> {
        let url = format!("{}{}", AUTH_BASE_URL, TOKEN_ENDPOINT);

        #[derive(Serialize)]
        struct RefreshRequest<'a> {
            client_id: &'a str,
            refresh_token: &'a str,
            grant_type: &'a str,
        }

        let request = RefreshRequest {
            client_id: &self.client_id,
            refresh_token,
            grant_type: "refresh_token",
        };

        let response = self
            .http
            .post(&url)
            .json(&request)
            .send()
            .await
            .map_err(|e| LlmError::Http(e.to_string()))?;

        if !response.status().is_success() {
            let text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());
            return Err(LlmError::Auth(format!("Token refresh failed: {}", text)));
        }

        let token_response: TokenApiResponse = response
            .json()
            .await
            .map_err(|e| LlmError::Json(e.to_string()))?;

        let expires_at = token_response.expires_in.map(|secs| {
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .map(|d| d.as_secs() + secs)
                .unwrap_or(0)
        });

        Ok(AuthMethod::OAuth {
            access_token: token_response.access_token,
            refresh_token: token_response.refresh_token,
            expires_at,
        })
    }
}

/// Helper to perform the full device auth flow with polling
pub async fn perform_device_auth<F>(
    auth: &OpenAIAuth,
    on_code: F,
) -> Result<AuthMethod, LlmError>
where
    F: FnOnce(&DeviceCodeResponse),
{
    // Start the device flow
    let device_code = auth.start_device_auth().await?;

    // Let caller display the code to user
    on_code(&device_code);

    // Poll for completion
    let poll_interval = std::time::Duration::from_secs(device_code.interval);
    let deadline =
        std::time::Instant::now() + std::time::Duration::from_secs(device_code.expires_in);

    loop {
        if std::time::Instant::now() > deadline {
            return Err(LlmError::Auth("Device code expired".to_string()));
        }

        tokio::time::sleep(poll_interval).await;

        match auth.poll_device_auth(&device_code.device_code).await? {
            DeviceAuthResult::Success {
                access_token,
                refresh_token,
                expires_in,
            } => {
                let expires_at = expires_in.map(|secs| {
                    std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .map(|d| d.as_secs() + secs)
                        .unwrap_or(0)
                });

                return Ok(AuthMethod::OAuth {
                    access_token,
                    refresh_token,
                    expires_at,
                });
            }
            DeviceAuthResult::Pending => continue,
            DeviceAuthResult::Denied => {
                return Err(LlmError::Auth("Authorization denied by user".to_string()));
            }
            DeviceAuthResult::Expired => {
                return Err(LlmError::Auth("Device code expired".to_string()));
            }
        }
    }
}
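For reviewers: a minimal sketch of how `perform_device_auth` might be driven from a CLI, assuming a tokio runtime. The caller and its message formatting are illustrative, not Owlen's actual login command.

```rust
use llm_core::{AuthMethod, LlmError};

// Usage sketch (hypothetical caller): start the device flow, show the code,
// then block until the user approves or the code expires.
async fn login_openai() -> Result<AuthMethod, LlmError> {
    let auth = OpenAIAuth::new();
    perform_device_auth(&auth, |code| {
        println!(
            "Open {} and enter the code {}",
            code.verification_uri, code.user_code
        );
    })
    .await
}
```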
561
crates/llm/openai/src/client.rs
Normal file
@@ -0,0 +1,561 @@
//! OpenAI GPT API Client
//!
//! Implements the Chat Completions API with streaming support.

use crate::types::*;
use async_trait::async_trait;
use futures::StreamExt;
use llm_core::{
    AccountInfo, AuthMethod, ChatMessage, ChatOptions, ChatResponse, ChunkStream, FunctionCall,
    LlmError, LlmProvider, ModelInfo, ProviderInfo, ProviderStatus, StreamChunk, Tool, ToolCall,
    ToolCallDelta, Usage, UsageStats,
};
use reqwest::Client;
use tokio::io::AsyncBufReadExt;
use tokio_stream::wrappers::LinesStream;
use tokio_util::io::StreamReader;

const API_BASE_URL: &str = "https://api.openai.com/v1";
const CHAT_ENDPOINT: &str = "/chat/completions";
const MODELS_ENDPOINT: &str = "/models";

/// OpenAI GPT API client
pub struct OpenAIClient {
    http: Client,
    auth: AuthMethod,
    model: String,
}

impl OpenAIClient {
    /// Create a new client with API key authentication
    pub fn new(api_key: impl Into<String>) -> Self {
        Self {
            http: Client::new(),
            auth: AuthMethod::api_key(api_key),
            model: "gpt-4o".to_string(),
        }
    }

    /// Create a new client with OAuth token
    pub fn with_oauth(access_token: impl Into<String>) -> Self {
        Self {
            http: Client::new(),
            auth: AuthMethod::oauth(access_token),
            model: "gpt-4o".to_string(),
        }
    }

    /// Create a new client with full AuthMethod
    pub fn with_auth(auth: AuthMethod) -> Self {
        Self {
            http: Client::new(),
            auth,
            model: "gpt-4o".to_string(),
        }
    }

    /// Set the model to use
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = model.into();
        self
    }

    /// Get current auth method (for token refresh)
    pub fn auth(&self) -> &AuthMethod {
        &self.auth
    }

    /// Update the auth method (after refresh)
    pub fn set_auth(&mut self, auth: AuthMethod) {
        self.auth = auth;
    }

    /// Convert messages to OpenAI format
    fn prepare_messages(messages: &[ChatMessage]) -> Vec<OpenAIMessage> {
        messages.iter().map(OpenAIMessage::from).collect()
    }

    /// Convert tools to OpenAI format
    fn prepare_tools(tools: Option<&[Tool]>) -> Option<Vec<OpenAITool>> {
        tools.map(|t| t.iter().map(OpenAITool::from).collect())
    }
}

#[async_trait]
impl LlmProvider for OpenAIClient {
    fn name(&self) -> &str {
        "openai"
    }

    fn model(&self) -> &str {
        &self.model
    }

    async fn chat_stream(
        &self,
        messages: &[ChatMessage],
        options: &ChatOptions,
        tools: Option<&[Tool]>,
    ) -> Result<ChunkStream, LlmError> {
        let url = format!("{}{}", API_BASE_URL, CHAT_ENDPOINT);

        let model = if options.model.is_empty() {
            &self.model
        } else {
            &options.model
        };

        let openai_messages = Self::prepare_messages(messages);
        let openai_tools = Self::prepare_tools(tools);

        let request = ChatCompletionRequest {
            model,
            messages: openai_messages,
            temperature: options.temperature,
            max_tokens: options.max_tokens,
            top_p: options.top_p,
            stop: options.stop.as_deref(),
            tools: openai_tools,
            tool_choice: None,
            stream: true,
        };

        let bearer = self
            .auth
            .bearer_token()
            .ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;

        let response = self
            .http
            .post(&url)
            .header("Authorization", format!("Bearer {}", bearer))
            .header("Content-Type", "application/json")
            .json(&request)
            .send()
            .await
            .map_err(|e| LlmError::Http(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());

            if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
                return Err(LlmError::RateLimit {
                    retry_after_secs: None,
                });
            }

            // Try to parse as error response
            if let Ok(err_resp) = serde_json::from_str::<ErrorResponse>(&text) {
                return Err(LlmError::Api {
                    message: err_resp.error.message,
                    code: err_resp.error.code,
                });
            }

            return Err(LlmError::Api {
                message: text,
                code: Some(status.to_string()),
            });
        }

        // Parse SSE stream
        let byte_stream = response
            .bytes_stream()
            .map(|result| result.map_err(std::io::Error::other));

        let reader = StreamReader::new(byte_stream);
        let buf_reader = tokio::io::BufReader::new(reader);
        let lines_stream = LinesStream::new(buf_reader.lines());

        let chunk_stream = lines_stream.filter_map(|line_result| async move {
            match line_result {
                Ok(line) => parse_sse_line(&line),
                Err(e) => Some(Err(LlmError::Stream(e.to_string()))),
            }
        });

        Ok(Box::pin(chunk_stream))
    }

    async fn chat(
        &self,
        messages: &[ChatMessage],
        options: &ChatOptions,
        tools: Option<&[Tool]>,
    ) -> Result<ChatResponse, LlmError> {
        let url = format!("{}{}", API_BASE_URL, CHAT_ENDPOINT);

        let model = if options.model.is_empty() {
            &self.model
        } else {
            &options.model
        };

        let openai_messages = Self::prepare_messages(messages);
        let openai_tools = Self::prepare_tools(tools);

        let request = ChatCompletionRequest {
            model,
            messages: openai_messages,
            temperature: options.temperature,
            max_tokens: options.max_tokens,
            top_p: options.top_p,
            stop: options.stop.as_deref(),
            tools: openai_tools,
            tool_choice: None,
            stream: false,
        };

        let bearer = self
            .auth
            .bearer_token()
            .ok_or_else(|| LlmError::Auth("No authentication configured".to_string()))?;

        let response = self
            .http
            .post(&url)
            .header("Authorization", format!("Bearer {}", bearer))
            .json(&request)
            .send()
            .await
            .map_err(|e| LlmError::Http(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let text = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error".to_string());

            if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
                return Err(LlmError::RateLimit {
                    retry_after_secs: None,
                });
            }

            if let Ok(err_resp) = serde_json::from_str::<ErrorResponse>(&text) {
                return Err(LlmError::Api {
                    message: err_resp.error.message,
                    code: err_resp.error.code,
                });
            }

            return Err(LlmError::Api {
                message: text,
                code: Some(status.to_string()),
            });
        }

        let api_response: ChatCompletionResponse = response
            .json()
            .await
            .map_err(|e| LlmError::Json(e.to_string()))?;

        // Extract the first choice
        let choice = api_response
            .choices
            .first()
            .ok_or_else(|| LlmError::Api {
                message: "No choices in response".to_string(),
                code: None,
            })?;

        let content = choice.message.content.clone();

        let tool_calls = choice.message.tool_calls.as_ref().map(|calls| {
            calls
                .iter()
                .map(|call| {
                    let arguments: serde_json::Value =
                        serde_json::from_str(&call.function.arguments).unwrap_or_default();

                    ToolCall {
                        id: call.id.clone(),
                        call_type: "function".to_string(),
                        function: FunctionCall {
                            name: call.function.name.clone(),
                            arguments,
                        },
                    }
                })
                .collect()
        });

        let usage = api_response.usage.map(|u| Usage {
            prompt_tokens: u.prompt_tokens,
            completion_tokens: u.completion_tokens,
            total_tokens: u.total_tokens,
        });

        Ok(ChatResponse {
            content,
            tool_calls,
            usage,
        })
    }
}

/// Parse a single SSE line into a StreamChunk
fn parse_sse_line(line: &str) -> Option<Result<StreamChunk, LlmError>> {
    let line = line.trim();

    // Skip empty lines and comments
    if line.is_empty() || line.starts_with(':') {
        return None;
    }

    // SSE format: "data: <json>"
    if let Some(data) = line.strip_prefix("data: ") {
        // OpenAI sends [DONE] to signal end
        if data == "[DONE]" {
            return Some(Ok(StreamChunk {
                content: None,
                tool_calls: None,
                done: true,
                usage: None,
            }));
        }

        // Parse the JSON chunk
        match serde_json::from_str::<ChatCompletionChunk>(data) {
            Ok(chunk) => Some(convert_chunk_to_stream_chunk(chunk)),
            Err(e) => {
                tracing::warn!("Failed to parse SSE chunk: {}", e);
                None
            }
        }
    } else {
        None
    }
}

/// Convert OpenAI chunk to our common format
fn convert_chunk_to_stream_chunk(chunk: ChatCompletionChunk) -> Result<StreamChunk, LlmError> {
    let choice = chunk.choices.first();

    if let Some(choice) = choice {
        let content = choice.delta.content.clone();

        let tool_calls = choice.delta.tool_calls.as_ref().map(|deltas| {
            deltas
                .iter()
                .map(|delta| ToolCallDelta {
                    index: delta.index,
                    id: delta.id.clone(),
                    function_name: delta.function.as_ref().and_then(|f| f.name.clone()),
                    arguments_delta: delta.function.as_ref().and_then(|f| f.arguments.clone()),
                })
                .collect()
        });

        let done = choice.finish_reason.is_some();

        Ok(StreamChunk {
            content,
            tool_calls,
            done,
            usage: None,
        })
    } else {
        // No choices, treat as done
        Ok(StreamChunk {
            content: None,
            tool_calls: None,
            done: true,
            usage: None,
        })
    }
}

// ============================================================================
// ProviderInfo Implementation
// ============================================================================

/// Known GPT models with their specifications
fn get_gpt_models() -> Vec<ModelInfo> {
    vec![
        ModelInfo {
            id: "gpt-4o".to_string(),
            display_name: Some("GPT-4o".to_string()),
            description: Some("Most advanced multimodal model with vision".to_string()),
            context_window: Some(128_000),
            max_output_tokens: Some(16_384),
            supports_tools: true,
            supports_vision: true,
            input_price_per_mtok: Some(2.50),
            output_price_per_mtok: Some(10.0),
        },
        ModelInfo {
            id: "gpt-4o-mini".to_string(),
            display_name: Some("GPT-4o mini".to_string()),
            description: Some("Affordable and fast model for simple tasks".to_string()),
            context_window: Some(128_000),
            max_output_tokens: Some(16_384),
            supports_tools: true,
            supports_vision: true,
            input_price_per_mtok: Some(0.15),
            output_price_per_mtok: Some(0.60),
        },
        ModelInfo {
            id: "gpt-4-turbo".to_string(),
            display_name: Some("GPT-4 Turbo".to_string()),
            description: Some("Previous generation high-performance model".to_string()),
            context_window: Some(128_000),
            max_output_tokens: Some(4_096),
            supports_tools: true,
            supports_vision: true,
            input_price_per_mtok: Some(10.0),
            output_price_per_mtok: Some(30.0),
        },
        ModelInfo {
            id: "gpt-3.5-turbo".to_string(),
            display_name: Some("GPT-3.5 Turbo".to_string()),
            description: Some("Fast and affordable for simple tasks".to_string()),
            context_window: Some(16_385),
            max_output_tokens: Some(4_096),
            supports_tools: true,
            supports_vision: false,
            input_price_per_mtok: Some(0.50),
            output_price_per_mtok: Some(1.50),
        },
        ModelInfo {
            id: "o1".to_string(),
            display_name: Some("OpenAI o1".to_string()),
            description: Some("Reasoning model optimized for complex problems".to_string()),
            context_window: Some(200_000),
            max_output_tokens: Some(100_000),
            supports_tools: false,
            supports_vision: true,
            input_price_per_mtok: Some(15.0),
            output_price_per_mtok: Some(60.0),
        },
        ModelInfo {
            id: "o1-mini".to_string(),
            display_name: Some("OpenAI o1-mini".to_string()),
            description: Some("Faster reasoning model for STEM".to_string()),
            context_window: Some(128_000),
            max_output_tokens: Some(65_536),
            supports_tools: false,
            supports_vision: true,
            input_price_per_mtok: Some(3.0),
            output_price_per_mtok: Some(12.0),
        },
    ]
}

#[async_trait]
impl ProviderInfo for OpenAIClient {
    async fn status(&self) -> Result<ProviderStatus, LlmError> {
        let authenticated = self.auth.bearer_token().is_some();

        // Try to reach the API by listing models
        let reachable = if authenticated {
            let url = format!("{}{}", API_BASE_URL, MODELS_ENDPOINT);
            let bearer = self.auth.bearer_token().unwrap();

            match self
                .http
                .get(&url)
                .header("Authorization", format!("Bearer {}", bearer))
                .send()
                .await
            {
                Ok(resp) => resp.status().is_success(),
                Err(_) => false,
            }
        } else {
            false
        };

        let message = if !authenticated {
            Some("Not authenticated - set OPENAI_API_KEY or run 'owlen login openai'".to_string())
        } else if !reachable {
            Some("Cannot reach OpenAI API".to_string())
        } else {
            Some("Connected".to_string())
        };

        Ok(ProviderStatus {
            provider: "openai".to_string(),
            authenticated,
            account: None, // OpenAI doesn't expose account info via API
            model: self.model.clone(),
            endpoint: API_BASE_URL.to_string(),
            reachable,
            message,
        })
    }

    async fn account_info(&self) -> Result<Option<AccountInfo>, LlmError> {
        // OpenAI doesn't have a public account info endpoint
        Ok(None)
    }

    async fn usage_stats(&self) -> Result<Option<UsageStats>, LlmError> {
        // OpenAI doesn't expose usage stats via the standard API
        Ok(None)
    }

    async fn list_models(&self) -> Result<Vec<ModelInfo>, LlmError> {
        // We can optionally fetch from API, but return known models for now
        Ok(get_gpt_models())
    }

    async fn model_info(&self, model_id: &str) -> Result<Option<ModelInfo>, LlmError> {
        let models = get_gpt_models();
        Ok(models.into_iter().find(|m| m.id == model_id))
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use llm_core::ToolParameters;
    use serde_json::json;

    #[test]
    fn test_message_conversion() {
        let messages = vec![
            ChatMessage::system("You are helpful"),
            ChatMessage::user("Hello"),
            ChatMessage::assistant("Hi there!"),
        ];

        let openai_msgs = OpenAIClient::prepare_messages(&messages);

        assert_eq!(openai_msgs.len(), 3);
        assert_eq!(openai_msgs[0].role, "system");
        assert_eq!(openai_msgs[1].role, "user");
        assert_eq!(openai_msgs[2].role, "assistant");
    }

    #[test]
    fn test_tool_conversion() {
        let tools = vec![Tool::function(
            "read_file",
            "Read a file's contents",
            ToolParameters::object(
                json!({
                    "path": {
                        "type": "string",
                        "description": "File path"
                    }
                }),
                vec!["path".to_string()],
            ),
        )];

        let openai_tools = OpenAIClient::prepare_tools(Some(&tools)).unwrap();

        assert_eq!(openai_tools.len(), 1);
        assert_eq!(openai_tools[0].function.name, "read_file");
        assert_eq!(
            openai_tools[0].function.description,
            "Read a file's contents"
        );
    }
}
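A minimal consumption sketch for `chat_stream`, printing deltas as they arrive. It assumes `ChatOptions` implements `Default`; llm-core's definition is not shown in this diff, so treat that as an assumption.

```rust
use futures::StreamExt;
use llm_core::{ChatMessage, ChatOptions, LlmError};

// Usage sketch (hypothetical caller): one streamed completion, tokens
// printed incrementally until the provider signals completion.
async fn stream_once(client: &OpenAIClient) -> Result<(), LlmError> {
    let messages = vec![ChatMessage::user("Hello")];
    let options = ChatOptions::default(); // assumed Default impl in llm-core

    let mut stream = client.chat_stream(&messages, &options, None).await?;
    while let Some(chunk) = stream.next().await {
        let chunk = chunk?;
        if let Some(text) = chunk.content {
            print!("{}", text);
        }
        if chunk.done {
            break;
        }
    }
    Ok(())
}
```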
12
crates/llm/openai/src/lib.rs
Normal file
@@ -0,0 +1,12 @@
//! OpenAI GPT API Client
//!
//! Implements the LlmProvider trait for OpenAI's GPT models.
//! Supports both API key authentication and OAuth device flow.

mod auth;
mod client;
mod types;

pub use auth::*;
pub use client::*;
pub use types::*;
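The two construction paths the crate doc mentions, sketched side by side; `OPENAI_API_KEY` is the variable the client's own status message points at, and the OAuth token value is a placeholder for one obtained via `perform_device_auth`.

```rust
// Construction sketch: API key vs. OAuth token from the device flow.
fn make_clients() -> (OpenAIClient, OpenAIClient) {
    let key = std::env::var("OPENAI_API_KEY").unwrap_or_default();
    let from_key = OpenAIClient::new(key);

    // Token obtained via `perform_device_auth` in auth.rs (value illustrative).
    let from_oauth = OpenAIClient::with_oauth("oauth-access-token");

    (from_key, from_oauth)
}
```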
285
crates/llm/openai/src/types.rs
Normal file
@@ -0,0 +1,285 @@
//! OpenAI API request/response types

use serde::{Deserialize, Serialize};
use serde_json::Value;

// ============================================================================
// Request Types
// ============================================================================

#[derive(Debug, Serialize)]
pub struct ChatCompletionRequest<'a> {
    pub model: &'a str,
    pub messages: Vec<OpenAIMessage>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<&'a [String]>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<OpenAITool>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<&'a str>,

    pub stream: bool,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIMessage {
    pub role: String, // "system", "user", "assistant", "tool"

    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<OpenAIToolCall>>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIToolCall {
    pub id: String,
    #[serde(rename = "type")]
    pub call_type: String,
    pub function: OpenAIFunctionCall,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIFunctionCall {
    pub name: String,
    pub arguments: String, // JSON string
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAITool {
    #[serde(rename = "type")]
    pub tool_type: String,
    pub function: OpenAIFunction,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OpenAIFunction {
    pub name: String,
    pub description: String,
    pub parameters: FunctionParameters,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionParameters {
    #[serde(rename = "type")]
    pub param_type: String,
    pub properties: Value,
    pub required: Vec<String>,
}

// ============================================================================
// Response Types
// ============================================================================

#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionResponse {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<Choice>,
    pub usage: Option<UsageInfo>,
}

#[derive(Debug, Clone, Deserialize)]
pub struct Choice {
    pub index: u32,
    pub message: OpenAIMessage,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Clone, Deserialize)]
pub struct UsageInfo {
    pub prompt_tokens: u32,
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

// ============================================================================
// Streaming Response Types
// ============================================================================

#[derive(Debug, Clone, Deserialize)]
pub struct ChatCompletionChunk {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub model: String,
    pub choices: Vec<ChunkChoice>,
}

#[derive(Debug, Clone, Deserialize)]
pub struct ChunkChoice {
    pub index: u32,
    pub delta: Delta,
    pub finish_reason: Option<String>,
}

#[derive(Debug, Clone, Deserialize)]
pub struct Delta {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub role: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub content: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<DeltaToolCall>>,
}

#[derive(Debug, Clone, Deserialize)]
pub struct DeltaToolCall {
    pub index: usize,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none", rename = "type")]
    pub call_type: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub function: Option<DeltaFunction>,
}

#[derive(Debug, Clone, Deserialize)]
pub struct DeltaFunction {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,

    #[serde(skip_serializing_if = "Option::is_none")]
    pub arguments: Option<String>,
}

// ============================================================================
// Error Response Types
// ============================================================================

#[derive(Debug, Clone, Deserialize)]
pub struct ErrorResponse {
    pub error: ApiError,
}

#[derive(Debug, Clone, Deserialize)]
pub struct ApiError {
    pub message: String,
    #[serde(rename = "type")]
    pub error_type: String,
    pub code: Option<String>,
}

// ============================================================================
// Models List Response
// ============================================================================

#[derive(Debug, Clone, Deserialize)]
pub struct ModelsResponse {
    pub object: String,
    pub data: Vec<ModelData>,
}

#[derive(Debug, Clone, Deserialize)]
pub struct ModelData {
    pub id: String,
    pub object: String,
    pub created: u64,
    pub owned_by: String,
}

// ============================================================================
// Conversions
// ============================================================================

impl From<&llm_core::Tool> for OpenAITool {
    fn from(tool: &llm_core::Tool) -> Self {
        Self {
            tool_type: "function".to_string(),
            function: OpenAIFunction {
                name: tool.function.name.clone(),
                description: tool.function.description.clone(),
                parameters: FunctionParameters {
                    param_type: tool.function.parameters.param_type.clone(),
                    properties: tool.function.parameters.properties.clone(),
                    required: tool.function.parameters.required.clone(),
                },
            },
        }
    }
}

impl From<&llm_core::ChatMessage> for OpenAIMessage {
    fn from(msg: &llm_core::ChatMessage) -> Self {
        use llm_core::Role;

        let role = match msg.role {
            Role::System => "system",
            Role::User => "user",
            Role::Assistant => "assistant",
            Role::Tool => "tool",
        };

        // Handle tool result messages
        if msg.role == Role::Tool {
            return Self {
                role: "tool".to_string(),
                content: msg.content.clone(),
                tool_calls: None,
                tool_call_id: msg.tool_call_id.clone(),
                name: msg.name.clone(),
            };
        }

        // Handle assistant messages with tool calls
        if msg.role == Role::Assistant && msg.tool_calls.is_some() {
            let tool_calls = msg.tool_calls.as_ref().map(|calls| {
                calls
                    .iter()
                    .map(|call| OpenAIToolCall {
                        id: call.id.clone(),
                        call_type: "function".to_string(),
                        function: OpenAIFunctionCall {
                            name: call.function.name.clone(),
                            arguments: serde_json::to_string(&call.function.arguments)
                                .unwrap_or_else(|_| "{}".to_string()),
                        },
                    })
                    .collect()
            });

            return Self {
                role: "assistant".to_string(),
                content: msg.content.clone(),
                tool_calls,
                tool_call_id: None,
                name: None,
            };
        }

        // Simple text message
        Self {
            role: role.to_string(),
            content: msg.content.clone(),
            tool_calls: None,
            tool_call_id: None,
            name: None,
        }
    }
}
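For reference, the wire shape these streaming types target; the payload values below are illustrative, matching OpenAI's documented `chat.completion.chunk` format.

```rust
#[cfg(test)]
mod wire_format_tests {
    use super::*;

    // Illustrative streaming payload (values are made up); verifies that
    // ChatCompletionChunk deserializes the documented SSE chunk shape.
    #[test]
    fn parses_example_chunk() {
        let data = r#"{"id":"chatcmpl-123","object":"chat.completion.chunk","created":1700000000,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}"#;
        let chunk: ChatCompletionChunk = serde_json::from_str(data).unwrap();
        assert_eq!(chunk.choices[0].delta.content.as_deref(), Some("Hello"));
        assert!(chunk.choices[0].finish_reason.is_none());
    }
}
```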