- Added a `tool_output` color to the `Theme` struct. - Updated all built-in themes to include the new color. - Modified the TUI to use the `tool_output` color for rendering tool output.
926 lines
33 KiB
Rust
926 lines
33 KiB
Rust
use crate::config::{Config, McpMode};
|
|
use crate::consent::ConsentManager;
|
|
use crate::conversation::ConversationManager;
|
|
use crate::credentials::CredentialManager;
|
|
use crate::encryption::{self, VaultHandle};
|
|
use crate::formatting::MessageFormatter;
|
|
use crate::input::InputBuffer;
|
|
use crate::mcp::client::{McpClient, RemoteMcpClient};
|
|
use crate::mcp::{LocalMcpClient, McpToolCall};
|
|
use crate::model::ModelManager;
|
|
use crate::provider::{ChatStream, Provider};
|
|
use crate::storage::{SessionMeta, StorageManager};
|
|
use crate::tools::{
|
|
code_exec::CodeExecTool,
|
|
fs_tools::{ResourcesGetTool, ResourcesListTool},
|
|
registry::ToolRegistry,
|
|
web_search::WebSearchTool,
|
|
web_search_detailed::WebSearchDetailedTool,
|
|
Tool,
|
|
};
|
|
use crate::types::{
|
|
ChatParameters, ChatRequest, ChatResponse, Conversation, Message, ModelInfo, ToolCall,
|
|
};
|
|
use crate::validation::{get_builtin_schemas, SchemaValidator};
|
|
use crate::{Error, Result};
|
|
use log::warn;
|
|
use std::env;
|
|
use std::path::PathBuf;
|
|
use std::sync::{Arc, Mutex};
|
|
use uuid::Uuid;
|
|
|
|
/// Outcome of submitting a chat request
|
|
pub enum SessionOutcome {
|
|
/// Immediate response received (non-streaming)
|
|
Complete(ChatResponse),
|
|
/// Streaming response where chunks will arrive asynchronously
|
|
Streaming {
|
|
response_id: Uuid,
|
|
stream: ChatStream,
|
|
},
|
|
}
|
|
|
|
/// High-level controller encapsulating session state and provider interactions.
|
|
///
|
|
/// This is the main entry point for managing conversations and interacting with LLM providers.
|
|
///
|
|
/// # Example
|
|
///
|
|
/// ```
|
|
/// use std::sync::Arc;
|
|
/// use owlen_core::config::Config;
|
|
/// use owlen_core::provider::{Provider, ChatStream};
|
|
/// use owlen_core::session::{SessionController, SessionOutcome};
|
|
/// use owlen_core::storage::StorageManager;
|
|
/// use owlen_core::types::{ChatRequest, ChatResponse, ChatParameters, Message, ModelInfo, Role};
|
|
/// use owlen_core::Result;
|
|
///
|
|
/// // Mock provider for the example
|
|
/// struct MockProvider;
|
|
/// #[async_trait::async_trait]
|
|
/// impl Provider for MockProvider {
|
|
/// fn name(&self) -> &str { "mock" }
|
|
/// async fn list_models(&self) -> Result<Vec<ModelInfo>> { Ok(vec![]) }
|
|
/// async fn chat(&self, request: ChatRequest) -> Result<ChatResponse> {
|
|
/// Ok(ChatResponse {
|
|
/// message: Message::assistant("Hello back!".to_string()),
|
|
/// usage: None,
|
|
/// is_streaming: false,
|
|
/// is_final: true,
|
|
/// })
|
|
/// }
|
|
/// async fn chat_stream(&self, request: ChatRequest) -> Result<ChatStream> { unimplemented!() }
|
|
/// async fn health_check(&self) -> Result<()> { Ok(()) }
|
|
/// }
|
|
///
|
|
/// #[tokio::main]
|
|
/// async fn main() {
|
|
/// let provider = Arc::new(MockProvider);
|
|
/// let config = Config::default();
|
|
/// let storage = Arc::new(StorageManager::new().await.unwrap());
|
|
/// let enable_code_tools = false; // Set to true for code client
|
|
/// let mut session_controller = SessionController::new(provider, config, storage, enable_code_tools).unwrap();
|
|
///
|
|
/// // Send a message
|
|
/// let outcome = session_controller.send_message(
|
|
/// "Hello".to_string(),
|
|
/// ChatParameters { stream: false, ..Default::default() }
|
|
/// ).await.unwrap();
|
|
///
|
|
/// // Check the response
|
|
/// if let SessionOutcome::Complete(response) = outcome {
|
|
/// assert_eq!(response.message.content, "Hello back!");
|
|
/// }
|
|
///
|
|
/// // The conversation now contains both messages
|
|
/// let messages = session_controller.conversation().messages.clone();
|
|
/// assert_eq!(messages.len(), 2);
|
|
/// assert_eq!(messages[0].content, "Hello");
|
|
/// assert_eq!(messages[1].content, "Hello back!");
|
|
/// }
|
|
/// ```
|
|
pub struct SessionController {
|
|
provider: Arc<dyn Provider>,
|
|
conversation: ConversationManager,
|
|
model_manager: ModelManager,
|
|
input_buffer: InputBuffer,
|
|
formatter: MessageFormatter,
|
|
config: Config,
|
|
consent_manager: Arc<Mutex<ConsentManager>>,
|
|
tool_registry: Arc<ToolRegistry>,
|
|
schema_validator: Arc<SchemaValidator>,
|
|
mcp_client: Arc<dyn McpClient>,
|
|
storage: Arc<StorageManager>,
|
|
vault: Option<Arc<Mutex<VaultHandle>>>,
|
|
master_key: Option<Arc<Vec<u8>>>,
|
|
credential_manager: Option<Arc<CredentialManager>>,
|
|
enable_code_tools: bool, // Whether to enable code execution tools (code client only)
|
|
}
|
|
|
|
fn build_tools(
|
|
config: &Config,
|
|
enable_code_tools: bool,
|
|
consent_manager: Arc<Mutex<ConsentManager>>,
|
|
credential_manager: Option<Arc<CredentialManager>>,
|
|
vault: Option<Arc<Mutex<VaultHandle>>>,
|
|
) -> Result<(Arc<ToolRegistry>, Arc<SchemaValidator>)> {
|
|
let mut registry = ToolRegistry::new();
|
|
let mut validator = SchemaValidator::new();
|
|
|
|
for (name, schema) in get_builtin_schemas() {
|
|
if let Err(err) = validator.register_schema(&name, schema) {
|
|
warn!("Failed to register built-in schema {name}: {err}");
|
|
}
|
|
}
|
|
|
|
if config
|
|
.security
|
|
.allowed_tools
|
|
.iter()
|
|
.any(|tool| tool == "web_search")
|
|
&& config.tools.web_search.enabled
|
|
&& config.privacy.enable_remote_search
|
|
{
|
|
let tool = WebSearchTool::new(
|
|
consent_manager.clone(),
|
|
credential_manager.clone(),
|
|
vault.clone(),
|
|
);
|
|
let schema = tool.schema();
|
|
if let Err(err) = validator.register_schema(tool.name(), schema) {
|
|
warn!("Failed to register schema for {}: {err}", tool.name());
|
|
}
|
|
registry.register(tool);
|
|
}
|
|
|
|
// Register web_search_detailed tool (provides snippets)
|
|
if config
|
|
.security
|
|
.allowed_tools
|
|
.iter()
|
|
.any(|tool| tool == "web_search") // Same permission as web_search
|
|
&& config.tools.web_search.enabled
|
|
&& config.privacy.enable_remote_search
|
|
{
|
|
let tool = WebSearchDetailedTool::new(
|
|
consent_manager.clone(),
|
|
credential_manager.clone(),
|
|
vault.clone(),
|
|
);
|
|
let schema = tool.schema();
|
|
if let Err(err) = validator.register_schema(tool.name(), schema) {
|
|
warn!("Failed to register schema for {}: {err}", tool.name());
|
|
}
|
|
registry.register(tool);
|
|
}
|
|
|
|
// Code execution tool - only available in code client
|
|
if enable_code_tools
|
|
&& config
|
|
.security
|
|
.allowed_tools
|
|
.iter()
|
|
.any(|tool| tool == "code_exec")
|
|
&& config.tools.code_exec.enabled
|
|
{
|
|
let tool = CodeExecTool::new(config.tools.code_exec.allowed_languages.clone());
|
|
let schema = tool.schema();
|
|
if let Err(err) = validator.register_schema(tool.name(), schema) {
|
|
warn!("Failed to register schema for {}: {err}", tool.name());
|
|
}
|
|
registry.register(tool);
|
|
}
|
|
|
|
let resources_list_tool = ResourcesListTool;
|
|
let resources_get_tool = ResourcesGetTool;
|
|
validator.register_schema(resources_list_tool.name(), resources_list_tool.schema())?;
|
|
validator.register_schema(resources_get_tool.name(), resources_get_tool.schema())?;
|
|
registry.register(resources_list_tool);
|
|
registry.register(resources_get_tool);
|
|
|
|
Ok((Arc::new(registry), Arc::new(validator)))
|
|
}
|
|
|
|
impl SessionController {
|
|
/// Create a new controller with the given provider and configuration
|
|
///
|
|
/// # Arguments
|
|
/// * `provider` - The LLM provider to use
|
|
/// * `config` - Application configuration
|
|
/// * `storage` - Storage manager for persistence
|
|
/// * `enable_code_tools` - Whether to enable code execution tools (should only be true for code client)
|
|
pub fn new(
|
|
provider: Arc<dyn Provider>,
|
|
config: Config,
|
|
storage: Arc<StorageManager>,
|
|
enable_code_tools: bool,
|
|
) -> Result<Self> {
|
|
let model = config
|
|
.general
|
|
.default_model
|
|
.clone()
|
|
.unwrap_or_else(|| "ollama/default".to_string());
|
|
|
|
let mut vault_handle: Option<Arc<Mutex<VaultHandle>>> = None;
|
|
let mut master_key: Option<Arc<Vec<u8>>> = None;
|
|
let mut credential_manager: Option<Arc<CredentialManager>> = None;
|
|
|
|
if config.privacy.encrypt_local_data {
|
|
let base_dir = storage
|
|
.database_path()
|
|
.parent()
|
|
.map(|p| p.to_path_buf())
|
|
.or_else(dirs::data_local_dir)
|
|
.unwrap_or_else(|| PathBuf::from("."));
|
|
let secure_path = base_dir.join("encrypted_data.json");
|
|
|
|
let handle = match env::var("OWLEN_MASTER_PASSWORD") {
|
|
Ok(password) if !password.is_empty() => {
|
|
encryption::unlock_with_password(secure_path, &password)?
|
|
}
|
|
_ => encryption::unlock_interactive(secure_path)?,
|
|
};
|
|
|
|
let master = Arc::new(handle.data.master_key.clone());
|
|
master_key = Some(master.clone());
|
|
vault_handle = Some(Arc::new(Mutex::new(handle)));
|
|
credential_manager = Some(Arc::new(CredentialManager::new(storage.clone(), master)));
|
|
}
|
|
|
|
// Load consent manager from vault if available, otherwise create new
|
|
let consent_manager = if let Some(ref vault) = vault_handle {
|
|
Arc::new(Mutex::new(ConsentManager::from_vault(vault)))
|
|
} else {
|
|
Arc::new(Mutex::new(ConsentManager::new()))
|
|
};
|
|
|
|
let conversation =
|
|
ConversationManager::with_history_capacity(model, config.storage.max_saved_sessions);
|
|
let formatter =
|
|
MessageFormatter::new(config.ui.wrap_column as usize, config.ui.show_role_labels)
|
|
.with_preserve_empty(config.ui.word_wrap);
|
|
let input_buffer = InputBuffer::new(
|
|
config.input.history_size,
|
|
config.input.multiline,
|
|
config.input.tab_width,
|
|
);
|
|
|
|
let model_manager = ModelManager::new(config.general.model_cache_ttl());
|
|
|
|
let (tool_registry, schema_validator) = build_tools(
|
|
&config,
|
|
enable_code_tools,
|
|
consent_manager.clone(),
|
|
credential_manager.clone(),
|
|
vault_handle.clone(),
|
|
)?;
|
|
|
|
let mcp_client: Arc<dyn McpClient> = match config.mcp.mode {
|
|
McpMode::Legacy => Arc::new(LocalMcpClient::new(
|
|
tool_registry.clone(),
|
|
schema_validator.clone(),
|
|
)),
|
|
McpMode::Enabled => Arc::new(RemoteMcpClient::new()?),
|
|
};
|
|
|
|
let controller = Self {
|
|
provider,
|
|
conversation,
|
|
model_manager,
|
|
input_buffer,
|
|
formatter,
|
|
config,
|
|
consent_manager,
|
|
tool_registry,
|
|
schema_validator,
|
|
mcp_client,
|
|
storage,
|
|
vault: vault_handle,
|
|
master_key,
|
|
credential_manager,
|
|
enable_code_tools,
|
|
};
|
|
|
|
Ok(controller)
|
|
}
|
|
|
|
/// Access the active conversation
|
|
pub fn conversation(&self) -> &Conversation {
|
|
self.conversation.active()
|
|
}
|
|
|
|
/// Mutable access to the conversation manager
|
|
pub fn conversation_mut(&mut self) -> &mut ConversationManager {
|
|
&mut self.conversation
|
|
}
|
|
|
|
/// Access input buffer
|
|
pub fn input_buffer(&self) -> &InputBuffer {
|
|
&self.input_buffer
|
|
}
|
|
|
|
/// Mutable input buffer access
|
|
pub fn input_buffer_mut(&mut self) -> &mut InputBuffer {
|
|
&mut self.input_buffer
|
|
}
|
|
|
|
/// Formatter for rendering messages
|
|
pub fn formatter(&self) -> &MessageFormatter {
|
|
&self.formatter
|
|
}
|
|
|
|
/// Update the wrap width of the message formatter
|
|
pub fn set_formatter_wrap_width(&mut self, width: usize) {
|
|
self.formatter.set_wrap_width(width);
|
|
}
|
|
|
|
/// Access configuration
|
|
pub fn config(&self) -> &Config {
|
|
&self.config
|
|
}
|
|
|
|
/// Mutable configuration access
|
|
pub fn config_mut(&mut self) -> &mut Config {
|
|
&mut self.config
|
|
}
|
|
|
|
/// Grant consent programmatically for a tool (for TUI consent dialog)
|
|
pub fn grant_consent(&self, tool_name: &str, data_types: Vec<String>, endpoints: Vec<String>) {
|
|
let mut consent = self
|
|
.consent_manager
|
|
.lock()
|
|
.expect("Consent manager mutex poisoned");
|
|
consent.grant_consent(tool_name, data_types, endpoints);
|
|
|
|
// Persist to vault if available
|
|
if let Some(vault) = &self.vault {
|
|
if let Err(e) = consent.persist_to_vault(vault) {
|
|
eprintln!("Warning: Failed to persist consent to vault: {}", e);
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Check if consent is needed for tool calls (non-blocking check)
|
|
/// Returns a list of (tool_name, data_types, endpoints) tuples for tools that need consent
|
|
pub fn check_tools_consent_needed(
|
|
&self,
|
|
tool_calls: &[ToolCall],
|
|
) -> Vec<(String, Vec<String>, Vec<String>)> {
|
|
let consent = self
|
|
.consent_manager
|
|
.lock()
|
|
.expect("Consent manager mutex poisoned");
|
|
let mut needs_consent = Vec::new();
|
|
let mut seen_tools = std::collections::HashSet::new();
|
|
|
|
for tool_call in tool_calls {
|
|
// Skip if we already checked this tool
|
|
if seen_tools.contains(&tool_call.name) {
|
|
continue;
|
|
}
|
|
seen_tools.insert(tool_call.name.clone());
|
|
|
|
// Get tool metadata (data types and endpoints) based on tool name
|
|
let (data_types, endpoints) = match tool_call.name.as_str() {
|
|
"web_search" | "web_search_detailed" => (
|
|
vec!["search query".to_string()],
|
|
vec!["duckduckgo.com".to_string()],
|
|
),
|
|
"code_exec" => (
|
|
vec!["code to execute".to_string()],
|
|
vec!["local sandbox".to_string()],
|
|
),
|
|
_ => (vec![], vec![]),
|
|
};
|
|
|
|
if let Some((tool_name, dt, ep)) =
|
|
consent.check_if_consent_needed(&tool_call.name, data_types, endpoints)
|
|
{
|
|
needs_consent.push((tool_name, dt, ep));
|
|
}
|
|
}
|
|
|
|
needs_consent
|
|
}
|
|
|
|
/// Persist the active conversation to storage
|
|
pub async fn save_active_session(
|
|
&self,
|
|
name: Option<String>,
|
|
description: Option<String>,
|
|
) -> Result<Uuid> {
|
|
self.conversation
|
|
.save_active_with_description(&self.storage, name, description)
|
|
.await
|
|
}
|
|
|
|
/// Persist the active conversation without description override
|
|
pub async fn save_active_session_simple(&self, name: Option<String>) -> Result<Uuid> {
|
|
self.conversation.save_active(&self.storage, name).await
|
|
}
|
|
|
|
/// Load a saved conversation by ID and make it active
|
|
pub async fn load_saved_session(&mut self, id: Uuid) -> Result<()> {
|
|
self.conversation.load_saved(&self.storage, id).await
|
|
}
|
|
|
|
/// Retrieve session metadata from storage
|
|
pub async fn list_saved_sessions(&self) -> Result<Vec<SessionMeta>> {
|
|
ConversationManager::list_saved_sessions(&self.storage).await
|
|
}
|
|
|
|
/// Delete the saved session `id` from storage.
pub async fn delete_session(&self, id: Uuid) -> Result<()> {
    let storage = &self.storage;
    storage.delete_session(id).await
}
|
|
|
|
pub async fn clear_secure_data(&self) -> Result<()> {
|
|
self.storage.clear_secure_items().await?;
|
|
if let Some(vault) = &self.vault {
|
|
let mut guard = vault.lock().expect("Vault mutex poisoned");
|
|
guard.data.settings.clear();
|
|
guard.persist()?;
|
|
}
|
|
// Also clear consent records
|
|
{
|
|
let mut consent = self
|
|
.consent_manager
|
|
.lock()
|
|
.expect("Consent manager mutex poisoned");
|
|
consent.clear_all_consent();
|
|
}
|
|
self.persist_consent()?;
|
|
Ok(())
|
|
}
|
|
|
|
/// Persist current consent state to vault (if encryption is enabled)
|
|
pub fn persist_consent(&self) -> Result<()> {
|
|
if let Some(vault) = &self.vault {
|
|
let consent = self
|
|
.consent_manager
|
|
.lock()
|
|
.expect("Consent manager mutex poisoned");
|
|
consent.persist_to_vault(vault)?;
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn set_tool_enabled(&mut self, tool: &str, enabled: bool) -> Result<()> {
|
|
match tool {
|
|
"web_search" => {
|
|
self.config.tools.web_search.enabled = enabled;
|
|
self.config.privacy.enable_remote_search = enabled;
|
|
}
|
|
"code_exec" => {
|
|
self.config.tools.code_exec.enabled = enabled;
|
|
}
|
|
other => {
|
|
return Err(Error::InvalidInput(format!("Unknown tool: {other}")));
|
|
}
|
|
}
|
|
|
|
self.rebuild_tools()?;
|
|
Ok(())
|
|
}
|
|
|
|
/// Access the consent manager shared across tools
|
|
pub fn consent_manager(&self) -> Arc<Mutex<ConsentManager>> {
|
|
self.consent_manager.clone()
|
|
}
|
|
|
|
/// Access the tool registry for executing registered tools
|
|
pub fn tool_registry(&self) -> Arc<ToolRegistry> {
|
|
Arc::clone(&self.tool_registry)
|
|
}
|
|
|
|
/// Access the schema validator used for tool input validation
|
|
pub fn schema_validator(&self) -> Arc<SchemaValidator> {
|
|
Arc::clone(&self.schema_validator)
|
|
}
|
|
|
|
/// Construct an MCP server facade for the active tool registry
|
|
pub fn mcp_server(&self) -> crate::mcp::McpServer {
|
|
crate::mcp::McpServer::new(self.tool_registry(), self.schema_validator())
|
|
}
|
|
|
|
/// Access the underlying storage manager
|
|
pub fn storage(&self) -> Arc<StorageManager> {
|
|
Arc::clone(&self.storage)
|
|
}
|
|
|
|
/// Retrieve the active master key if encryption is enabled
|
|
pub fn master_key(&self) -> Option<Arc<Vec<u8>>> {
|
|
self.master_key.as_ref().map(Arc::clone)
|
|
}
|
|
|
|
/// Access the vault handle for managing secure settings
|
|
pub fn vault(&self) -> Option<Arc<Mutex<VaultHandle>>> {
|
|
self.vault.as_ref().map(Arc::clone)
|
|
}
|
|
|
|
/// Access the credential manager if available
|
|
pub fn credential_manager(&self) -> Option<Arc<CredentialManager>> {
|
|
self.credential_manager.as_ref().map(Arc::clone)
|
|
}
|
|
|
|
pub async fn read_file(&self, path: &str) -> Result<String> {
|
|
let call = McpToolCall {
|
|
name: "resources/get".to_string(),
|
|
arguments: serde_json::json!({ "path": path }),
|
|
};
|
|
let response = self.mcp_client.call_tool(call).await?;
|
|
let content: String = serde_json::from_value(response.output)?;
|
|
Ok(content)
|
|
}
|
|
|
|
pub async fn list_dir(&self, path: &str) -> Result<Vec<String>> {
|
|
let call = McpToolCall {
|
|
name: "resources/list".to_string(),
|
|
arguments: serde_json::json!({ "path": path }),
|
|
};
|
|
let response = self.mcp_client.call_tool(call).await?;
|
|
let content: Vec<String> = serde_json::from_value(response.output)?;
|
|
Ok(content)
|
|
}
|
|
|
|
fn rebuild_tools(&mut self) -> Result<()> {
|
|
let (registry, validator) = build_tools(
|
|
&self.config,
|
|
self.enable_code_tools,
|
|
self.consent_manager.clone(),
|
|
self.credential_manager.clone(),
|
|
self.vault.clone(),
|
|
)?;
|
|
self.tool_registry = registry;
|
|
self.schema_validator = validator;
|
|
|
|
self.mcp_client = match self.config.mcp.mode {
|
|
McpMode::Legacy => Arc::new(LocalMcpClient::new(
|
|
self.tool_registry.clone(),
|
|
self.schema_validator.clone(),
|
|
)),
|
|
McpMode::Enabled => Arc::new(RemoteMcpClient::new()?),
|
|
};
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Currently selected model identifier
|
|
pub fn selected_model(&self) -> &str {
|
|
&self.conversation.active().model
|
|
}
|
|
|
|
/// Change current model for upcoming requests
|
|
pub fn set_model(&mut self, model: String) {
|
|
self.conversation.set_model(model.clone());
|
|
self.config.general.default_model = Some(model);
|
|
}
|
|
|
|
/// Retrieve cached models, refreshing from provider as needed
|
|
pub async fn models(&self, force_refresh: bool) -> Result<Vec<ModelInfo>> {
|
|
self.model_manager
|
|
.get_or_refresh(force_refresh, || async {
|
|
self.provider.list_models().await
|
|
})
|
|
.await
|
|
}
|
|
|
|
/// Attempt to select the configured default model from cached models
|
|
pub fn ensure_default_model(&mut self, models: &[ModelInfo]) {
|
|
if let Some(default) = self.config.general.default_model.clone() {
|
|
if models.iter().any(|m| m.id == default || m.name == default) {
|
|
self.set_model(default);
|
|
}
|
|
} else if let Some(model) = models.first() {
|
|
self.set_model(model.id.clone());
|
|
}
|
|
}
|
|
|
|
/// Replace the active provider at runtime and invalidate cached model listings
|
|
pub async fn switch_provider(&mut self, provider: Arc<dyn Provider>) -> Result<()> {
|
|
self.provider = provider;
|
|
self.model_manager.invalidate().await;
|
|
Ok(())
|
|
}
|
|
|
|
/// Submit a user message; optionally stream the response
|
|
pub async fn send_message(
|
|
&mut self,
|
|
content: String,
|
|
mut parameters: ChatParameters,
|
|
) -> Result<SessionOutcome> {
|
|
let streaming = parameters.stream || self.config.general.enable_streaming;
|
|
parameters.stream = streaming;
|
|
|
|
self.conversation.push_user_message(content);
|
|
|
|
self.send_request_with_current_conversation(parameters)
|
|
.await
|
|
}
|
|
|
|
/// Send a request using the current conversation without adding a new user message
|
|
pub async fn send_request_with_current_conversation(
|
|
&mut self,
|
|
mut parameters: ChatParameters,
|
|
) -> Result<SessionOutcome> {
|
|
let streaming = parameters.stream || self.config.general.enable_streaming;
|
|
parameters.stream = streaming;
|
|
|
|
// Get available tools
|
|
let tools = if !self.tool_registry.all().is_empty() {
|
|
Some(
|
|
self.tool_registry
|
|
.all()
|
|
.into_iter()
|
|
.map(|tool| crate::mcp::McpToolDescriptor {
|
|
name: tool.name().to_string(),
|
|
description: tool.description().to_string(),
|
|
input_schema: tool.schema(),
|
|
requires_network: tool.requires_network(),
|
|
requires_filesystem: tool.requires_filesystem(),
|
|
})
|
|
.collect(),
|
|
)
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let mut request = ChatRequest {
|
|
model: self.conversation.active().model.clone(),
|
|
messages: self.conversation.active().messages.clone(),
|
|
parameters: parameters.clone(),
|
|
tools: tools.clone(),
|
|
};
|
|
|
|
// Tool execution loop (non-streaming only for now)
|
|
if !streaming {
|
|
const MAX_TOOL_ITERATIONS: usize = 5;
|
|
for _iteration in 0..MAX_TOOL_ITERATIONS {
|
|
match self.provider.chat(request.clone()).await {
|
|
Ok(response) => {
|
|
// Check if the response has tool calls
|
|
if response.message.has_tool_calls() {
|
|
// Add assistant's tool call message to conversation
|
|
self.conversation.push_message(response.message.clone());
|
|
|
|
// Execute each tool call
|
|
if let Some(tool_calls) = &response.message.tool_calls {
|
|
for tool_call in tool_calls {
|
|
let mcp_tool_call = McpToolCall {
|
|
name: tool_call.name.clone(),
|
|
arguments: tool_call.arguments.clone(),
|
|
};
|
|
|
|
let tool_result =
|
|
self.mcp_client.call_tool(mcp_tool_call).await;
|
|
|
|
let tool_response_content = match tool_result {
|
|
Ok(result) => serde_json::to_string_pretty(&result.output)
|
|
.unwrap_or_else(|_| {
|
|
"Tool execution succeeded".to_string()
|
|
}),
|
|
Err(e) => format!("Tool execution failed: {}", e),
|
|
};
|
|
|
|
// Add tool response to conversation
|
|
let tool_msg =
|
|
Message::tool(tool_call.id.clone(), tool_response_content);
|
|
self.conversation.push_message(tool_msg);
|
|
}
|
|
}
|
|
|
|
// Update request with new messages for next iteration
|
|
request.messages = self.conversation.active().messages.clone();
|
|
continue;
|
|
} else {
|
|
// No more tool calls, return final response
|
|
self.conversation.push_message(response.message.clone());
|
|
return Ok(SessionOutcome::Complete(response));
|
|
}
|
|
}
|
|
Err(err) => {
|
|
self.conversation
|
|
.push_assistant_message(format!("Error: {}", err));
|
|
return Err(err);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Max iterations reached
|
|
self.conversation
|
|
.push_assistant_message("Maximum tool execution iterations reached".to_string());
|
|
return Err(crate::Error::Provider(anyhow::anyhow!(
|
|
"Maximum tool execution iterations reached"
|
|
)));
|
|
}
|
|
|
|
// Streaming mode with tool support
|
|
match self.provider.chat_stream(request).await {
|
|
Ok(stream) => {
|
|
let response_id = self.conversation.start_streaming_response();
|
|
Ok(SessionOutcome::Streaming {
|
|
response_id,
|
|
stream,
|
|
})
|
|
}
|
|
Err(err) => {
|
|
self.conversation
|
|
.push_assistant_message(format!("Error starting stream: {}", err));
|
|
Err(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Mark a streaming response message with placeholder content
|
|
pub fn mark_stream_placeholder(&mut self, message_id: Uuid, text: &str) -> Result<()> {
|
|
self.conversation
|
|
.set_stream_placeholder(message_id, text.to_string())
|
|
}
|
|
|
|
/// Apply streaming chunk to the conversation
|
|
pub fn apply_stream_chunk(&mut self, message_id: Uuid, chunk: &ChatResponse) -> Result<()> {
|
|
// Check if this chunk contains tool calls
|
|
if chunk.message.has_tool_calls() {
|
|
// This is a tool call chunk - store the tool calls on the message
|
|
self.conversation.set_tool_calls_on_message(
|
|
message_id,
|
|
chunk.message.tool_calls.clone().unwrap_or_default(),
|
|
)?;
|
|
}
|
|
|
|
self.conversation
|
|
.append_stream_chunk(message_id, &chunk.message.content, chunk.is_final)
|
|
}
|
|
|
|
/// Check if a streaming message has complete tool calls that need execution
|
|
pub fn check_streaming_tool_calls(&self, message_id: Uuid) -> Option<Vec<ToolCall>> {
|
|
self.conversation
|
|
.active()
|
|
.messages
|
|
.iter()
|
|
.find(|m| m.id == message_id)
|
|
.and_then(|m| m.tool_calls.clone())
|
|
.filter(|calls| !calls.is_empty())
|
|
}
|
|
|
|
/// Execute tools for a streaming response and continue conversation
|
|
pub async fn execute_streaming_tools(
|
|
&mut self,
|
|
_message_id: Uuid,
|
|
tool_calls: Vec<ToolCall>,
|
|
) -> Result<SessionOutcome> {
|
|
// Execute each tool call
|
|
for tool_call in &tool_calls {
|
|
let mcp_tool_call = McpToolCall {
|
|
name: tool_call.name.clone(),
|
|
arguments: tool_call.arguments.clone(),
|
|
};
|
|
let tool_result = self.mcp_client.call_tool(mcp_tool_call).await;
|
|
|
|
let tool_response_content = match tool_result {
|
|
Ok(result) => serde_json::to_string_pretty(&result.output)
|
|
.unwrap_or_else(|_| "Tool execution succeeded".to_string()),
|
|
Err(e) => format!("Tool execution failed: {}", e),
|
|
};
|
|
|
|
// Add tool response to conversation
|
|
let tool_msg = Message::tool(tool_call.id.clone(), tool_response_content);
|
|
self.conversation.push_message(tool_msg);
|
|
}
|
|
|
|
// Continue the conversation with tool results
|
|
let parameters = ChatParameters {
|
|
stream: self.config.general.enable_streaming,
|
|
..Default::default()
|
|
};
|
|
|
|
self.send_request_with_current_conversation(parameters)
|
|
.await
|
|
}
|
|
|
|
/// Access conversation history
|
|
pub fn history(&self) -> Vec<Conversation> {
|
|
self.conversation.history().cloned().collect()
|
|
}
|
|
|
|
/// Start a new conversation optionally targeting a specific model
|
|
pub fn start_new_conversation(&mut self, model: Option<String>, name: Option<String>) {
|
|
self.conversation.start_new(model, name);
|
|
}
|
|
|
|
/// Clear current conversation messages
|
|
pub fn clear(&mut self) {
|
|
self.conversation.clear();
|
|
}
|
|
|
|
/// Generate a short AI description for the current conversation
|
|
pub async fn generate_conversation_description(&self) -> Result<String> {
|
|
let conv = self.conversation.active();
|
|
|
|
// If conversation is empty or very short, return a simple description
|
|
if conv.messages.is_empty() {
|
|
return Ok("Empty conversation".to_string());
|
|
}
|
|
|
|
if conv.messages.len() == 1 {
|
|
let first_msg = &conv.messages[0];
|
|
let preview = first_msg.content.chars().take(50).collect::<String>();
|
|
return Ok(format!(
|
|
"{}{} ",
|
|
preview,
|
|
if first_msg.content.len() > 50 {
|
|
"..."
|
|
} else {
|
|
""
|
|
}
|
|
));
|
|
}
|
|
|
|
// Build a summary prompt from the first few and last few messages
|
|
let mut summary_messages = Vec::new();
|
|
|
|
// Add system message to guide the description
|
|
summary_messages.push(crate::types::Message::system(
|
|
"Summarize this conversation in 1-2 short sentences (max 100 characters). \
|
|
Focus on the main topic or question being discussed. Be concise and descriptive."
|
|
.to_string(),
|
|
));
|
|
|
|
// Include first message
|
|
if let Some(first) = conv.messages.first() {
|
|
summary_messages.push(first.clone());
|
|
}
|
|
|
|
// Include a middle message if conversation is long enough
|
|
if conv.messages.len() > 4 {
|
|
if let Some(mid) = conv.messages.get(conv.messages.len() / 2) {
|
|
summary_messages.push(mid.clone());
|
|
}
|
|
}
|
|
|
|
// Include last message
|
|
if let Some(last) = conv.messages.last() {
|
|
if conv.messages.len() > 1 {
|
|
summary_messages.push(last.clone());
|
|
}
|
|
}
|
|
|
|
// Create a summarization request
|
|
let request = crate::types::ChatRequest {
|
|
model: conv.model.clone(),
|
|
messages: summary_messages,
|
|
parameters: crate::types::ChatParameters {
|
|
temperature: Some(0.3), // Lower temperature for more focused summaries
|
|
max_tokens: Some(50), // Keep it short
|
|
stream: false,
|
|
extra: std::collections::HashMap::new(),
|
|
},
|
|
tools: None,
|
|
};
|
|
|
|
// Get the summary from the provider
|
|
match self.provider.chat(request).await {
|
|
Ok(response) => {
|
|
let description = response.message.content.trim().to_string();
|
|
|
|
// If description is empty, use fallback
|
|
if description.is_empty() {
|
|
let first_msg = &conv.messages[0];
|
|
let preview = first_msg.content.chars().take(50).collect::<String>();
|
|
return Ok(format!(
|
|
"{}{} ",
|
|
preview,
|
|
if first_msg.content.len() > 50 {
|
|
"..."
|
|
} else {
|
|
""
|
|
}
|
|
));
|
|
}
|
|
|
|
// Truncate if too long
|
|
let truncated = if description.len() > 100 {
|
|
description.chars().take(97).collect::<String>()
|
|
// Removed trailing '...' as it's already handled by the format! macro
|
|
} else {
|
|
description
|
|
};
|
|
Ok(truncated)
|
|
}
|
|
Err(_e) => {
|
|
// Fallback to simple description if AI generation fails
|
|
let first_msg = &conv.messages[0];
|
|
let preview = first_msg.content.chars().take(50).collect::<String>();
|
|
Ok(format!(
|
|
"{}{} ",
|
|
preview,
|
|
if first_msg.content.len() > 50 {
|
|
"..."
|
|
} else {
|
|
""
|
|
}
|
|
))
|
|
}
|
|
}
|
|
}
|
|
}
|