use crate::config::{
    ChatSettings, CompressionStrategy, Config, LEGACY_OLLAMA_CLOUD_API_KEY_ENV,
    LEGACY_OWLEN_OLLAMA_CLOUD_API_KEY_ENV, McpResourceConfig, McpServerConfig, OLLAMA_API_KEY_ENV,
    OLLAMA_CLOUD_BASE_URL,
};
use crate::consent::{ConsentManager, ConsentScope};
use crate::conversation::ConversationManager;
use crate::credentials::CredentialManager;
use crate::encryption::{self, VaultHandle};
use crate::formatting::MessageFormatter;
use crate::input::InputBuffer;
use crate::llm::ProviderConfig;
use crate::mcp::McpToolCall;
use crate::mcp::client::McpClient;
use crate::mcp::factory::McpClientFactory;
use crate::mcp::permission::PermissionLayer;
use crate::mcp::remote_client::{McpRuntimeSecrets, RemoteMcpClient};
use crate::mode::Mode;
use crate::model::{DetailedModelInfo, ModelManager};
use crate::oauth::{DeviceAuthorization, DevicePollState, OAuthClient};
use crate::providers::OllamaProvider;
use crate::providers::ollama::normalize_cloud_base_url;
use crate::storage::{SessionMeta, StorageManager};
use crate::tools::{WEB_SEARCH_TOOL_NAME, canonical_tool_name, tool_name_matches};
use crate::types::{
    ChatParameters, ChatRequest, ChatResponse, Conversation, Message, ModelInfo, Role, ToolCall,
};
use crate::ui::{RoleLabelDisplay, UiController};
use crate::usage::{UsageLedger, UsageQuota, UsageSnapshot};
use crate::validation::{SchemaValidator, get_builtin_schemas};
use crate::{ChatStream, Provider};
use crate::{
    CodeExecTool, ResourcesDeleteTool, ResourcesGetTool, ResourcesListTool, ResourcesWriteTool,
    ToolRegistry, WebScrapeTool, WebSearchSettings, WebSearchTool,
};
use crate::{Error, Result};
use chrono::{DateTime, Utc};
use log::{info, warn};
use reqwest::Url;
use serde_json::{Value, json};
use std::cmp::{max, min};
use std::collections::{HashMap, HashSet};
use std::env;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime};
use tokio::fs;
use tokio::sync::Mutex as TokioMutex;
use tokio::sync::mpsc::UnboundedSender;
use uuid::Uuid;

fn env_var_non_empty(name: &str) -> Option<String> {
    env::var(name)
        .ok()
        .map(|value| value.trim().to_string())
        .filter(|value| !value.is_empty())
}

fn estimate_tokens(messages: &[Message]) -> u32 {
    messages
        .iter()
        .map(estimate_message_tokens)
        .fold(0u32, |acc, value| acc.saturating_add(value))
}

fn estimate_message_tokens(message: &Message) -> u32 {
    let content = message.content.trim();
    if content.is_empty() {
        return 4;
    }
    let approx = max(4, content.chars().count() / 4 + 1);
    (approx + 4) as u32
}
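
// The token estimate above is a character-count heuristic (roughly one token per
// four characters plus a small per-message overhead), not a real tokenizer. The
// sketch below only illustrates the properties the compression logic relies on;
// the message texts are invented for the example.
#[cfg(test)]
mod token_estimate_tests {
    use super::*;
    use crate::types::Message;

    #[test]
    fn empty_messages_still_cost_a_few_tokens() {
        // An empty message falls back to the fixed 4-token floor.
        assert_eq!(estimate_message_tokens(&Message::user(String::new())), 4);
    }

    #[test]
    fn estimate_grows_with_content_and_sums_over_messages() {
        let short = Message::user("hi".to_string());
        let long = Message::user("x".repeat(400));
        // Longer content yields a larger estimate.
        assert!(estimate_message_tokens(&long) > estimate_message_tokens(&short));
        // The conversation-level estimate is the saturating sum of per-message estimates.
        let both = vec![short.clone(), long.clone()];
        assert_eq!(
            estimate_tokens(&both),
            estimate_message_tokens(&short) + estimate_message_tokens(&long)
        );
    }
}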

fn build_transcript(messages: &[Message]) -> String {
    let mut transcript = String::new();
    let take = min(messages.len(), MAX_TRANSCRIPT_MESSAGES);
    for message in messages.iter().take(take) {
        let role = match message.role {
            Role::User => "User",
            Role::Assistant => "Assistant",
            Role::System => "System",
            Role::Tool => "Tool",
        };
        let snippet = sanitize_snippet(&message.content);
        if snippet.is_empty() {
            continue;
        }
        transcript.push_str(&format!("{role}: {snippet}\n\n"));
    }
    if messages.len() > take {
        transcript.push_str(&format!(
            "... ({} additional messages omitted for brevity)\n",
            messages.len() - take
        ));
    }
    transcript
}
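
// Illustrative check of the transcript shape fed to the summarizer: each entry is
// rendered as "Role: snippet" and anything beyond MAX_TRANSCRIPT_MESSAGES is
// replaced by a single omission note. The sample messages are invented for the test.
#[cfg(test)]
mod transcript_tests {
    use super::*;
    use crate::types::Message;

    #[test]
    fn transcript_labels_roles_and_notes_omissions() {
        let mut messages = vec![
            Message::system("You are a helpful assistant.".to_string()),
            Message::user("What is the capital of France?".to_string()),
        ];
        let transcript = build_transcript(&messages);
        assert!(transcript.contains("System: You are a helpful assistant."));
        assert!(transcript.contains("User: What is the capital of France?"));

        // Push past the transcript cap so the omission marker appears.
        for i in 0..(MAX_TRANSCRIPT_MESSAGES + 5) {
            messages.push(Message::user(format!("follow-up {i}")));
        }
        let truncated = build_transcript(&messages);
        assert!(truncated.contains("additional messages omitted for brevity"));
    }
}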

fn local_summary(messages: &[Message]) -> String {
    if messages.is_empty() {
        return "(no content to summarize)".to_string();
    }
    let total = messages.len();
    let mut summary = String::from("Summary (local heuristic)\n\n");
    summary.push_str(&format!("- Compressed {total} prior messages.\n"));

    let recent_users = collect_recent_by_role(messages, Role::User, 3);
    if !recent_users.is_empty() {
        summary.push_str("- Recent user intents:\n");
        for intent in recent_users {
            summary.push_str(&format!(" - {intent}\n"));
        }
    }

    let recent_assistant = collect_recent_by_role(messages, Role::Assistant, 3);
    if !recent_assistant.is_empty() {
        summary.push_str("- Recent assistant responses:\n");
        for reply in recent_assistant {
            summary.push_str(&format!(" - {reply}\n"));
        }
    }

    summary.trim_end().to_string()
}

fn collect_recent_by_role(messages: &[Message], role: Role, limit: usize) -> Vec<String> {
    if limit == 0 {
        return Vec::new();
    }
    let mut results = Vec::new();
    for message in messages.iter().rev() {
        if message.role == role {
            let snippet = sanitize_snippet(&message.content);
            if !snippet.is_empty() {
                results.push(snippet);
                if results.len() == limit {
                    break;
                }
            }
        }
    }
    results.reverse();
    results
}
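
// Small sketch of how recent-message collection behaves: it walks the history
// backwards, keeps at most `limit` non-empty snippets for the requested role, and
// returns them in chronological order. The message texts are examples only.
#[cfg(test)]
mod recent_by_role_tests {
    use super::*;
    use crate::types::{Message, Role};

    #[test]
    fn keeps_the_latest_entries_in_chronological_order() {
        let messages = vec![
            Message::user("first question".to_string()),
            Message::system("system note".to_string()),
            Message::user("second question".to_string()),
            Message::user("third question".to_string()),
        ];
        let recent = collect_recent_by_role(&messages, Role::User, 2);
        assert_eq!(recent, vec!["second question", "third question"]);
        assert!(collect_recent_by_role(&messages, Role::Assistant, 2).is_empty());
        assert!(collect_recent_by_role(&messages, Role::User, 0).is_empty());
    }
}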

fn sanitize_snippet(content: &str) -> String {
    let trimmed = content.trim();
    if trimmed.is_empty() {
        return String::new();
    }
    let mut snippet = trimmed.replace('\r', "");
    if snippet.len() > MAX_TRANSCRIPT_MESSAGE_CHARS {
        snippet.truncate(MAX_TRANSCRIPT_MESSAGE_CHARS);
        snippet.push_str("...");
    }
    snippet
}
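
// Illustrates the snippet sanitizer above: surrounding whitespace and carriage
// returns are dropped, and anything longer than MAX_TRANSCRIPT_MESSAGE_CHARS is
// truncated with a trailing ellipsis.
#[cfg(test)]
mod snippet_tests {
    use super::*;

    #[test]
    fn trims_and_strips_carriage_returns() {
        assert_eq!(sanitize_snippet("  hello\r\nworld  "), "hello\nworld");
        assert_eq!(sanitize_snippet("   \r\n  "), "");
    }

    #[test]
    fn truncates_overlong_snippets() {
        let long = "a".repeat(MAX_TRANSCRIPT_MESSAGE_CHARS + 100);
        let snippet = sanitize_snippet(&long);
        assert_eq!(snippet.len(), MAX_TRANSCRIPT_MESSAGE_CHARS + "...".len());
        assert!(snippet.ends_with("..."));
    }
}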

fn compute_web_search_settings(
    config: &Config,
    provider_id: &str,
) -> Result<Option<WebSearchSettings>> {
    let provider_id = provider_id.trim();
    let provider_config = match config.providers.get(provider_id) {
        Some(cfg) => cfg,
        None => return Ok(None),
    };

    if !provider_config.enabled {
        return Ok(None);
    }

    if provider_config
        .provider_type
        .trim()
        .eq_ignore_ascii_case("ollama")
    {
        // Local Ollama does not expose web search.
        return Ok(None);
    }

    if !provider_config
        .provider_type
        .trim()
        .eq_ignore_ascii_case("ollama_cloud")
    {
        return Ok(None);
    }

    let raw_base_url = provider_config
        .base_url
        .as_deref()
        .filter(|value| !value.trim().is_empty());
    let normalized_base_url = normalize_cloud_base_url(raw_base_url).map_err(|err| {
        let display_base = raw_base_url.unwrap_or(OLLAMA_CLOUD_BASE_URL);
        Error::Config(format!(
            "Invalid Ollama Cloud base_url '{}': {err}",
            display_base
        ))
    })?;

    let endpoint = provider_config
        .extra
        .get("web_search_endpoint")
        .and_then(|value| value.as_str())
        .unwrap_or("/api/web_search");

    let endpoint_url = build_search_url(&normalized_base_url, endpoint)?;

    let api_key = resolve_web_search_api_key(provider_config)
        .or_else(|| env_var_non_empty(OLLAMA_API_KEY_ENV))
        .or_else(|| env_var_non_empty(LEGACY_OLLAMA_CLOUD_API_KEY_ENV))
        .or_else(|| env_var_non_empty(LEGACY_OWLEN_OLLAMA_CLOUD_API_KEY_ENV));

    let api_key = match api_key {
        Some(key) if !key.is_empty() => key,
        _ => return Ok(None),
    };

    let settings = WebSearchSettings {
        endpoint: endpoint_url,
        api_key,
        provider_label: provider_id.to_string(),
        timeout: Duration::from_secs(20),
    };

    Ok(Some(settings))
}

fn resolve_web_search_api_key(provider_config: &ProviderConfig) -> Option<String> {
    resolve_inline_api_key(provider_config.api_key.as_deref()).or_else(|| {
        provider_config
            .api_key_env
            .as_deref()
            .and_then(|var| env_var_non_empty(var.trim()))
    })
}

fn resolve_inline_api_key(value: Option<&str>) -> Option<String> {
    let raw = value?.trim();
    if raw.is_empty() {
        return None;
    }

    if let Some(inner) = raw
        .strip_prefix("${")
        .and_then(|value| value.strip_suffix('}'))
        .map(str::trim)
    {
        return env_var_non_empty(inner);
    }

    if let Some(inner) = raw.strip_prefix('$').map(str::trim) {
        return env_var_non_empty(inner);
    }

    Some(raw.to_string())
}
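
// Sketch of the inline key resolution rules: literal values pass through, `$VAR`
// and `${VAR}` forms are looked up in the environment, and blank values resolve to
// nothing. Only the environment-free cases are exercised here to keep the test
// hermetic; the key string is a placeholder, not a real secret.
#[cfg(test)]
mod inline_api_key_tests {
    use super::*;

    #[test]
    fn passes_literals_through_and_ignores_blanks() {
        assert_eq!(
            resolve_inline_api_key(Some("  sk-example-123  ")),
            Some("sk-example-123".to_string())
        );
        assert_eq!(resolve_inline_api_key(Some("   ")), None);
        assert_eq!(resolve_inline_api_key(None), None);
    }
}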

fn build_search_url(base_url: &str, endpoint: &str) -> Result<Url> {
    let endpoint = endpoint.trim();
    if let Ok(url) = Url::parse(endpoint) {
        return Ok(url);
    }

    let trimmed_base = base_url.trim();
    let normalized_base = if trimmed_base.ends_with('/') {
        trimmed_base.to_string()
    } else {
        format!("{}/", trimmed_base)
    };

    let base = Url::parse(&normalized_base).map_err(|err| {
        Error::Config(format!("Invalid provider base_url '{}': {}", base_url, err))
    })?;

    if endpoint.is_empty() {
        return Ok(base);
    }

    base.join(endpoint.trim_start_matches('/')).map_err(|err| {
        Error::Config(format!(
            "Invalid web_search_endpoint '{}': {}",
            endpoint, err
        ))
    })
}
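
// Illustrative behaviour of the URL assembly above: an absolute endpoint wins
// outright, otherwise the endpoint path is joined onto the provider base URL. The
// hosts used here are examples, not endpoints the crate actually contacts.
#[cfg(test)]
mod search_url_tests {
    use super::*;

    #[test]
    fn joins_relative_endpoints_onto_the_base_url() {
        let url = build_search_url("https://ollama.com", "/api/web_search").unwrap();
        assert_eq!(url.as_str(), "https://ollama.com/api/web_search");
    }

    #[test]
    fn keeps_absolute_endpoints_as_is() {
        let url =
            build_search_url("https://ollama.com", "https://search.example.com/v1/query").unwrap();
        assert_eq!(url.as_str(), "https://search.example.com/v1/query");
    }
}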

pub enum SessionOutcome {
    Complete(ChatResponse),
    Streaming {
        response_id: Uuid,
        stream: ChatStream,
    },
}

#[derive(Debug, Clone)]
pub enum ControllerEvent {
    ToolRequested {
        request_id: Uuid,
        message_id: Uuid,
        tool_name: String,
        data_types: Vec<String>,
        endpoints: Vec<String>,
        tool_calls: Vec<ToolCall>,
    },
    CompressionCompleted {
        report: CompressionReport,
    },
}

#[derive(Clone, Debug)]
struct PendingToolRequest {
    message_id: Uuid,
    tool_name: String,
    data_types: Vec<String>,
    endpoints: Vec<String>,
    tool_calls: Vec<ToolCall>,
}

#[derive(Debug, Clone)]
pub struct CompressionReport {
    pub summary_message_id: Uuid,
    pub compressed_messages: usize,
    pub estimated_tokens_before: u32,
    pub estimated_tokens_after: u32,
    pub strategy: CompressionStrategy,
    pub model_used: String,
    pub retained_recent: usize,
    pub automated: bool,
    pub timestamp: DateTime<Utc>,
}

#[derive(Debug, Clone)]
struct CompressionOptions {
    trigger_tokens: u32,
    retain_recent: usize,
    strategy: CompressionStrategy,
    model_override: Option<String>,
}

impl CompressionOptions {
    fn from_settings(settings: &ChatSettings) -> Self {
        Self {
            trigger_tokens: settings.trigger_tokens.max(64),
            retain_recent: settings.retain_recent_messages.max(2),
            strategy: settings.strategy,
            model_override: settings.model_override.clone(),
        }
    }

    fn min_chunk_messages(&self) -> usize {
        self.retain_recent.saturating_add(2).max(4)
    }

    fn resolve_model<'a>(&'a self, active_model: &'a str) -> String {
        self.model_override
            .clone()
            .filter(|model| !model.trim().is_empty())
            .unwrap_or_else(|| active_model.to_string())
    }
}

const MAX_TRANSCRIPT_MESSAGE_CHARS: usize = 1024;
const MAX_TRANSCRIPT_MESSAGES: usize = 32;
const COMPRESSION_METADATA_KEY: &str = "compression";

#[derive(Debug, Default)]
struct StreamingMessageState {
    full_text: String,
    last_tool_calls: Option<Vec<ToolCall>>,
    finished: bool,
}

#[derive(Debug)]
struct StreamDiff {
    text: Option<TextDelta>,
    tool_calls: Option<Vec<ToolCall>>,
}

#[derive(Debug)]
struct TextDelta {
    content: String,
    mode: TextDeltaKind,
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TextDeltaKind {
    Append,
    Replace,
}

impl StreamingMessageState {
    fn new() -> Self {
        Self::default()
    }

    fn ingest(&mut self, chunk: &ChatResponse) -> StreamDiff {
        if self.finished {
            return StreamDiff {
                text: None,
                tool_calls: None,
            };
        }

        let mut text_delta = None;
        let incoming = chunk.message.content.clone();

        if incoming != self.full_text {
            if incoming.starts_with(&self.full_text) {
                let delta = incoming[self.full_text.len()..].to_string();
                if !delta.is_empty() {
                    text_delta = Some(TextDelta {
                        content: delta,
                        mode: TextDeltaKind::Append,
                    });
                }
            } else {
                text_delta = Some(TextDelta {
                    content: incoming.clone(),
                    mode: TextDeltaKind::Replace,
                });
            }
            self.full_text = incoming;
        }

        let mut tool_delta = None;
        if let Some(tool_calls) = chunk.message.tool_calls.clone() {
            if tool_calls.is_empty() {
                let previously_had_calls = self
                    .last_tool_calls
                    .as_ref()
                    .map(|prev| !prev.is_empty())
                    .unwrap_or(false);
                if previously_had_calls {
                    tool_delta = Some(Vec::new());
                }
                self.last_tool_calls = None;
            } else {
                let is_new = self
                    .last_tool_calls
                    .as_ref()
                    .map(|prev| prev != &tool_calls)
                    .unwrap_or(true);
                if is_new {
                    tool_delta = Some(tool_calls.clone());
                }
                self.last_tool_calls = Some(tool_calls);
            }
        }

        StreamDiff {
            text: text_delta,
            tool_calls: tool_delta,
        }
    }

    fn mark_finished(&mut self) {
        self.finished = true;
    }
}

#[derive(Debug, Clone)]
pub struct ToolConsentResolution {
    pub request_id: Uuid,
    pub message_id: Uuid,
    pub tool_name: String,
    pub scope: ConsentScope,
    pub tool_calls: Vec<ToolCall>,
}

fn extract_resource_content(value: &Value) -> Option<String> {
    match value {
        Value::Null => Some(String::new()),
        Value::Bool(flag) => Some(flag.to_string()),
        Value::Number(num) => Some(num.to_string()),
        Value::String(text) => Some(text.clone()),
        Value::Array(items) => {
            let mut segments = Vec::new();
            for item in items {
                if let Some(segment) =
                    extract_resource_content(item).filter(|segment| !segment.is_empty())
                {
                    segments.push(segment);
                }
            }
            if segments.is_empty() {
                None
            } else {
                Some(segments.join("\n"))
            }
        }
        Value::Object(map) => {
            const PREFERRED_FIELDS: [&str; 6] =
                ["content", "contents", "text", "value", "body", "data"];
            for key in PREFERRED_FIELDS.iter() {
                if let Some(text) = map
                    .get(*key)
                    .and_then(extract_resource_content)
                    .filter(|text| !text.is_empty())
                {
                    return Some(text);
                }
            }

            if let Some(text) = map
                .get("chunks")
                .and_then(extract_resource_content)
                .filter(|text| !text.is_empty())
            {
                return Some(text);
            }

            None
        }
    }
}
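
// Quick sketch of the extraction rules above: preferred object fields such as
// "content" or "text" win, arrays are flattened with newlines, and objects without
// any recognised field yield nothing. The JSON shapes are invented for illustration.
#[cfg(test)]
mod resource_content_tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn prefers_known_fields_and_flattens_arrays() {
        assert_eq!(
            extract_resource_content(&json!({ "content": "file body" })),
            Some("file body".to_string())
        );
        assert_eq!(
            extract_resource_content(&json!(["first chunk", "second chunk"])),
            Some("first chunk\nsecond chunk".to_string())
        );
        assert_eq!(extract_resource_content(&json!({ "unrelated": 42 })), None);
    }
}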

pub struct SessionController {
    provider: Arc<dyn Provider>,
    conversation: ConversationManager,
    model_manager: ModelManager,
    input_buffer: InputBuffer,
    formatter: MessageFormatter,
    config: Arc<TokioMutex<Config>>,
    consent_manager: Arc<Mutex<ConsentManager>>,
    tool_registry: Arc<ToolRegistry>,
    schema_validator: Arc<SchemaValidator>,
    mcp_client: Arc<dyn McpClient>,
    named_mcp_clients: HashMap<String, Arc<dyn McpClient>>,
    storage: Arc<StorageManager>,
    vault: Option<Arc<Mutex<VaultHandle>>>,
    master_key: Option<Arc<Vec<u8>>>,
    credential_manager: Option<Arc<CredentialManager>>,
    ui: Arc<dyn UiController>,
    enable_code_tools: bool,
    current_mode: Mode,
    missing_oauth_servers: Vec<String>,
    event_tx: Option<UnboundedSender<ControllerEvent>>,
    pending_tool_requests: HashMap<Uuid, PendingToolRequest>,
    stream_states: HashMap<Uuid, StreamingMessageState>,
    usage_ledger: Arc<TokioMutex<UsageLedger>>,
    last_compression: Option<CompressionReport>,
}

async fn build_tools(
    config: Arc<TokioMutex<Config>>,
    ui: Arc<dyn UiController>,
    enable_code_tools: bool,
    consent_manager: Arc<Mutex<ConsentManager>>,
    _credential_manager: Option<Arc<CredentialManager>>,
    _vault: Option<Arc<Mutex<VaultHandle>>>,
) -> Result<(Arc<ToolRegistry>, Arc<SchemaValidator>)> {
    let mut registry = ToolRegistry::new(config.clone(), ui);
    let mut validator = SchemaValidator::new();
    // Acquire config asynchronously to avoid blocking the async runtime.
    let config_guard = config.lock().await;

    for (name, schema) in get_builtin_schemas() {
        if let Err(err) = validator.register_schema(&name, schema) {
            warn!("Failed to register built-in schema {name}: {err}");
        }
    }

    let active_provider_id = config_guard.general.default_provider.clone();

    let web_search_settings = if config_guard
        .security
        .allowed_tools
        .iter()
        .any(|tool| tool_name_matches(tool, WEB_SEARCH_TOOL_NAME))
        && config_guard.tools.web_search.enabled
        && config_guard.privacy.enable_remote_search
    {
        match compute_web_search_settings(&config_guard, &active_provider_id) {
            Ok(settings) => settings,
            Err(err) => {
                warn!("Skipping web_search tool: {}", err);
                None
            }
        }
    } else {
        None
    };

    if let Some(settings) = web_search_settings {
        let tool = WebSearchTool::new(consent_manager.clone(), settings);
        registry.register(tool)?;
    }

    // Register web_scrape tool if allowed.
    if config_guard
        .security
        .allowed_tools
        .iter()
        .any(|tool| tool_name_matches(tool, "web_scrape"))
        && config_guard.tools.web_search.enabled // reuse web_search toggle for simplicity
        && config_guard.privacy.enable_remote_search
    {
        let tool = WebScrapeTool::new();
        registry.register(tool)?;
    }

    if enable_code_tools
        && config_guard
            .security
            .allowed_tools
            .iter()
            .any(|tool| tool == "code_exec")
        && config_guard.tools.code_exec.enabled
    {
        let tool = CodeExecTool::new(config_guard.tools.code_exec.allowed_languages.clone());
        registry.register(tool)?;
    }

    registry.register(ResourcesListTool)?;
    registry.register(ResourcesGetTool)?;

    if config_guard
        .security
        .allowed_tools
        .iter()
        .any(|t| t == "file_write")
    {
        registry.register(ResourcesWriteTool)?;
    }
    if config_guard
        .security
        .allowed_tools
        .iter()
        .any(|t| t == "file_delete")
    {
        registry.register(ResourcesDeleteTool)?;
    }

    for tool in registry.all() {
        if let Err(err) = validator.register_schema(tool.name(), tool.schema()) {
            warn!("Failed to register schema for {}: {err}", tool.name());
        }
    }

    Ok((Arc::new(registry), Arc::new(validator)))
}
|
|
|
|
impl SessionController {
|
|
async fn create_mcp_clients(
|
|
config: Arc<TokioMutex<Config>>,
|
|
tool_registry: Arc<ToolRegistry>,
|
|
schema_validator: Arc<SchemaValidator>,
|
|
credential_manager: Option<Arc<CredentialManager>>,
|
|
initial_mode: Mode,
|
|
) -> Result<(
|
|
Arc<dyn McpClient>,
|
|
HashMap<String, Arc<dyn McpClient>>,
|
|
Vec<String>,
|
|
)> {
|
|
let guard = config.lock().await;
|
|
let config_arc = Arc::new(guard.clone());
|
|
let factory = McpClientFactory::new(config_arc.clone(), tool_registry, schema_validator);
|
|
|
|
let mut missing_oauth_servers = Vec::new();
|
|
let primary_runtime = if let Some(primary_cfg) = guard.effective_mcp_servers().first() {
|
|
let (runtime, missing) =
|
|
Self::runtime_secrets_for_server(credential_manager.clone(), primary_cfg).await?;
|
|
if missing {
|
|
missing_oauth_servers.push(primary_cfg.name.clone());
|
|
}
|
|
runtime
|
|
} else {
|
|
None
|
|
};
|
|
|
|
let base_client = factory.create_with_secrets(primary_runtime)?;
|
|
let primary: Arc<dyn McpClient> =
|
|
Arc::new(PermissionLayer::new(base_client, config_arc.clone()));
|
|
primary.set_mode(initial_mode).await?;
|
|
|
|
let mut clients: HashMap<String, Arc<dyn McpClient>> = HashMap::new();
|
|
if let Some(primary_cfg) = guard.effective_mcp_servers().first() {
|
|
clients.insert(primary_cfg.name.clone(), Arc::clone(&primary));
|
|
}
|
|
|
|
for server_cfg in guard.effective_mcp_servers().iter().skip(1) {
|
|
let (runtime, missing) =
|
|
Self::runtime_secrets_for_server(credential_manager.clone(), server_cfg).await?;
|
|
if missing {
|
|
missing_oauth_servers.push(server_cfg.name.clone());
|
|
}
|
|
|
|
match RemoteMcpClient::new_with_runtime(server_cfg, runtime) {
|
|
Ok(remote) => {
|
|
let client: Arc<dyn McpClient> =
|
|
Arc::new(PermissionLayer::new(Box::new(remote), config_arc.clone()));
|
|
if let Err(err) = client.set_mode(initial_mode).await {
|
|
warn!(
|
|
"Failed to initialize MCP server '{}' in mode {:?}: {}",
|
|
server_cfg.name, initial_mode, err
|
|
);
|
|
}
|
|
clients.insert(server_cfg.name.clone(), Arc::clone(&client));
|
|
}
|
|
Err(err) => warn!(
|
|
"Failed to initialize MCP server '{}': {}",
|
|
server_cfg.name, err
|
|
),
|
|
}
|
|
}
|
|
|
|
drop(guard);
|
|
|
|
Ok((primary, clients, missing_oauth_servers))
|
|
}
|
|
|
|
async fn runtime_secrets_for_server(
|
|
credential_manager: Option<Arc<CredentialManager>>,
|
|
server: &McpServerConfig,
|
|
) -> Result<(Option<McpRuntimeSecrets>, bool)> {
|
|
if let Some(oauth) = &server.oauth {
|
|
if let Some(manager) = credential_manager {
|
|
match manager.load_oauth_token(&server.name).await? {
|
|
Some(token) => {
|
|
if token.access_token.trim().is_empty() || token.is_expired(Utc::now()) {
|
|
return Ok((None, true));
|
|
}
|
|
let mut secrets = McpRuntimeSecrets::default();
|
|
if let Some(env_name) = oauth.token_env.as_deref() {
|
|
secrets
|
|
.env_overrides
|
|
.insert(env_name.to_string(), token.access_token.clone());
|
|
}
|
|
if matches!(
|
|
server.transport.to_ascii_lowercase().as_str(),
|
|
"http" | "websocket"
|
|
) {
|
|
let header_value =
|
|
format!("{}{}", oauth.header_prefix(), token.access_token);
|
|
secrets.http_header =
|
|
Some((oauth.header_name().to_string(), header_value));
|
|
}
|
|
Ok((Some(secrets), false))
|
|
}
|
|
None => Ok((None, true)),
|
|
}
|
|
} else {
|
|
Ok((None, true))
|
|
}
|
|
} else {
|
|
Ok((None, false))
|
|
}
|
|
}
|
|
|
|
pub async fn new(
|
|
provider: Arc<dyn Provider>,
|
|
config: Config,
|
|
storage: Arc<StorageManager>,
|
|
ui: Arc<dyn UiController>,
|
|
enable_code_tools: bool,
|
|
event_tx: Option<UnboundedSender<ControllerEvent>>,
|
|
) -> Result<Self> {
|
|
let config_arc = Arc::new(TokioMutex::new(config));
|
|
// Acquire the config asynchronously to avoid blocking the runtime.
|
|
let config_guard = config_arc.lock().await;
|
|
|
|
let model = config_guard
|
|
.general
|
|
.default_model
|
|
.clone()
|
|
.unwrap_or_else(|| "ollama/default".to_string());
|
|
|
|
let mut vault_handle: Option<Arc<Mutex<VaultHandle>>> = None;
|
|
let mut master_key: Option<Arc<Vec<u8>>> = None;
|
|
let mut credential_manager: Option<Arc<CredentialManager>> = None;
|
|
|
|
if config_guard.privacy.encrypt_local_data {
|
|
let base_dir = storage
|
|
.database_path()
|
|
.parent()
|
|
.map(|p| p.to_path_buf())
|
|
.or_else(dirs::data_local_dir)
|
|
.unwrap_or_else(|| PathBuf::from("."));
|
|
let secure_path = base_dir.join("encrypted_data.json");
|
|
let handle = match env::var("OWLEN_MASTER_PASSWORD") {
|
|
Ok(password) if !password.is_empty() => {
|
|
encryption::unlock_with_password(secure_path, &password)?
|
|
}
|
|
_ => encryption::unlock_interactive(secure_path)?,
|
|
};
|
|
let master = Arc::new(handle.data.master_key.clone());
|
|
master_key = Some(master.clone());
|
|
vault_handle = Some(Arc::new(Mutex::new(handle)));
|
|
credential_manager = Some(Arc::new(CredentialManager::new(storage.clone(), master)));
|
|
}
|
|
|
|
let consent_manager = if let Some(ref vault) = vault_handle {
|
|
Arc::new(Mutex::new(ConsentManager::from_vault(vault)))
|
|
} else {
|
|
Arc::new(Mutex::new(ConsentManager::new()))
|
|
};
|
|
|
|
let conversation = ConversationManager::with_history_capacity(
|
|
model,
|
|
config_guard.storage.max_saved_sessions,
|
|
);
|
|
let formatter = MessageFormatter::new(
|
|
config_guard.ui.wrap_column as usize,
|
|
config_guard.ui.role_label_mode,
|
|
)
|
|
.with_preserve_empty(config_guard.ui.word_wrap);
|
|
let input_buffer = InputBuffer::new(
|
|
config_guard.input.history_size,
|
|
config_guard.input.multiline,
|
|
config_guard.input.tab_width,
|
|
);
|
|
let model_manager = ModelManager::new(config_guard.general.model_cache_ttl());
|
|
|
|
drop(config_guard); // Release the lock before calling build_tools
|
|
|
|
let initial_mode = if enable_code_tools {
|
|
Mode::Code
|
|
} else {
|
|
Mode::Chat
|
|
};
|
|
|
|
let (tool_registry, schema_validator) = build_tools(
|
|
config_arc.clone(),
|
|
ui.clone(),
|
|
enable_code_tools,
|
|
consent_manager.clone(),
|
|
credential_manager.clone(),
|
|
vault_handle.clone(),
|
|
)
|
|
.await?;
|
|
|
|
let (mcp_client, named_mcp_clients, missing_oauth_servers) = Self::create_mcp_clients(
|
|
config_arc.clone(),
|
|
tool_registry.clone(),
|
|
schema_validator.clone(),
|
|
credential_manager.clone(),
|
|
initial_mode,
|
|
)
|
|
.await?;
|
|
|
|
let usage_ledger_path = storage
|
|
.database_path()
|
|
.parent()
|
|
.map(|dir| dir.join("usage-ledger.json"))
|
|
.unwrap_or_else(|| PathBuf::from("usage-ledger.json"));
|
|
|
|
let usage_ledger_instance =
|
|
match UsageLedger::load_or_default(usage_ledger_path.clone()).await {
|
|
Ok(ledger) => ledger,
|
|
Err(err) => {
|
|
warn!(
|
|
"Failed to load usage ledger at {}: {err}. Starting with an empty ledger.",
|
|
usage_ledger_path.display()
|
|
);
|
|
UsageLedger::empty(usage_ledger_path)
|
|
}
|
|
};
|
|
let usage_ledger = Arc::new(TokioMutex::new(usage_ledger_instance));
|
|
|
|
Ok(Self {
|
|
provider,
|
|
conversation,
|
|
model_manager,
|
|
input_buffer,
|
|
formatter,
|
|
config: config_arc,
|
|
consent_manager,
|
|
tool_registry,
|
|
schema_validator,
|
|
mcp_client,
|
|
named_mcp_clients,
|
|
storage,
|
|
vault: vault_handle,
|
|
master_key,
|
|
credential_manager,
|
|
ui,
|
|
enable_code_tools,
|
|
current_mode: initial_mode,
|
|
missing_oauth_servers,
|
|
event_tx,
|
|
pending_tool_requests: HashMap::new(),
|
|
stream_states: HashMap::new(),
|
|
usage_ledger,
|
|
last_compression: None,
|
|
})
|
|
}
|
|
|
|
pub fn conversation(&self) -> &Conversation {
|
|
self.conversation.active()
|
|
}
|
|
|
|
pub fn conversation_mut(&mut self) -> &mut ConversationManager {
|
|
&mut self.conversation
|
|
}
|
|
|
|
pub fn last_compression(&self) -> Option<CompressionReport> {
|
|
self.last_compression.clone()
|
|
}
|
|
|
|
pub fn input_buffer(&self) -> &InputBuffer {
|
|
&self.input_buffer
|
|
}
|
|
|
|
pub fn input_buffer_mut(&mut self) -> &mut InputBuffer {
|
|
&mut self.input_buffer
|
|
}
|
|
|
|
pub fn formatter(&self) -> &MessageFormatter {
|
|
&self.formatter
|
|
}
|
|
|
|
pub async fn set_formatter_wrap_width(&mut self, width: usize) {
|
|
self.formatter.set_wrap_width(width);
|
|
}
|
|
|
|
pub fn set_role_label_mode(&mut self, mode: RoleLabelDisplay) {
|
|
self.formatter.set_role_label_mode(mode);
|
|
}
|
|
|
|
/// Return the configured resource references aggregated across scopes.
|
|
pub async fn configured_resources(&self) -> Vec<McpResourceConfig> {
|
|
let guard = self.config.lock().await;
|
|
guard.effective_mcp_resources().to_vec()
|
|
}
|
|
|
|
/// Resolve a resource reference of the form `server:uri` (optionally prefixed with `@`).
|
|
pub async fn resolve_resource_reference(&self, reference: &str) -> Result<Option<String>> {
|
|
let (server, uri) = match Self::split_resource_reference(reference) {
|
|
Some(parts) => parts,
|
|
None => return Ok(None),
|
|
};
|
|
|
|
let resource_defined = {
|
|
let guard = self.config.lock().await;
|
|
guard.find_resource(&server, &uri).is_some()
|
|
};
|
|
|
|
if !resource_defined {
|
|
return Ok(None);
|
|
}
|
|
|
|
let client = self
|
|
.named_mcp_clients
|
|
.get(&server)
|
|
.cloned()
|
|
.ok_or_else(|| {
|
|
Error::Config(format!(
|
|
"MCP server '{}' referenced by resource '{}' is not available",
|
|
server, uri
|
|
))
|
|
})?;
|
|
|
|
let call = McpToolCall {
|
|
name: "resources_get".to_string(),
|
|
arguments: json!({ "uri": uri, "path": uri }),
|
|
};
|
|
let response = client.call_tool(call).await?;
|
|
if let Some(text) = extract_resource_content(&response.output) {
|
|
return Ok(Some(text));
|
|
}
|
|
|
|
let formatted = serde_json::to_string_pretty(&response.output)
|
|
.unwrap_or_else(|_| response.output.to_string());
|
|
Ok(Some(formatted))
|
|
}
|
|
|
|
fn split_resource_reference(reference: &str) -> Option<(String, String)> {
|
|
let trimmed = reference.trim();
|
|
let without_prefix = trimmed.strip_prefix('@').unwrap_or(trimmed);
|
|
let (server, uri) = without_prefix.split_once(':')?;
|
|
if server.is_empty() || uri.is_empty() {
|
|
return None;
|
|
}
|
|
Some((server.to_string(), uri.to_string()))
|
|
}
|
|
|
|
async fn persist_usage_serialized(path: PathBuf, serialized: String) {
|
|
if let Some(parent) = path.parent() {
|
|
if let Err(err) = fs::create_dir_all(parent).await {
|
|
warn!(
|
|
"Failed to create usage ledger directory {}: {}",
|
|
parent.display(),
|
|
err
|
|
);
|
|
return;
|
|
}
|
|
}
|
|
|
|
if let Err(err) = fs::write(&path, serialized).await {
|
|
warn!("Failed to write usage ledger {}: {}", path.display(), err);
|
|
}
|
|
}
|
|
|
|
fn parse_quota_value(value: &Value) -> Option<u64> {
|
|
match value {
|
|
Value::Number(num) => num.as_u64(),
|
|
Value::String(text) => text.trim().parse::<u64>().ok(),
|
|
_ => None,
|
|
}
|
|
}
|
|
|
|
fn quota_from_config(config: &Config, provider: &str) -> UsageQuota {
|
|
let mut quota = UsageQuota::default();
|
|
|
|
if let Some(entry) = config.providers.get(provider) {
|
|
if let Some(value) = entry.extra.get("hourly_quota_tokens") {
|
|
quota.hourly_quota_tokens = Self::parse_quota_value(value);
|
|
}
|
|
if let Some(value) = entry.extra.get("weekly_quota_tokens") {
|
|
quota.weekly_quota_tokens = Self::parse_quota_value(value);
|
|
}
|
|
}
|
|
|
|
quota
|
|
}
|
|
|
|
pub async fn record_usage_sample(
|
|
&self,
|
|
usage: &crate::types::TokenUsage,
|
|
) -> Option<UsageSnapshot> {
|
|
if usage.total_tokens == 0 {
|
|
return None;
|
|
}
|
|
|
|
let provider_name = self.provider.name().to_string();
|
|
if provider_name.trim().is_empty() {
|
|
return None;
|
|
}
|
|
|
|
let quotas = {
|
|
let guard = self.config.lock().await;
|
|
Self::quota_from_config(&guard, &provider_name)
|
|
};
|
|
|
|
let timestamp = SystemTime::now();
|
|
let mut serialized_payload: Option<(PathBuf, String)> = None;
|
|
|
|
let snapshot = {
|
|
let mut ledger = self.usage_ledger.lock().await;
|
|
ledger.record(&provider_name, usage, timestamp);
|
|
let snapshot = ledger.snapshot(&provider_name, quotas, timestamp);
|
|
match ledger.serialize() {
|
|
Ok(payload) => {
|
|
serialized_payload = Some((ledger.path().to_path_buf(), payload));
|
|
}
|
|
Err(err) => warn!("Failed to serialize usage ledger: {}", err),
|
|
}
|
|
snapshot
|
|
};
|
|
|
|
if let Some((path, payload)) = serialized_payload {
|
|
Self::persist_usage_serialized(path, payload).await;
|
|
}
|
|
|
|
Some(snapshot)
|
|
}
|
|
|
|
pub async fn current_usage_snapshot(&self) -> Option<UsageSnapshot> {
|
|
let provider_name = self.provider.name().to_string();
|
|
if provider_name.trim().is_empty() {
|
|
return None;
|
|
}
|
|
|
|
let quotas = {
|
|
let guard = self.config.lock().await;
|
|
Self::quota_from_config(&guard, &provider_name)
|
|
};
|
|
|
|
let now = SystemTime::now();
|
|
let ledger = self.usage_ledger.lock().await;
|
|
Some(ledger.snapshot(&provider_name, quotas, now))
|
|
}
|
|
|
|
pub async fn usage_overview(&self) -> Vec<UsageSnapshot> {
|
|
let quota_map = {
|
|
let guard = self.config.lock().await;
|
|
guard
|
|
.providers
|
|
.keys()
|
|
.map(|name| (name.clone(), Self::quota_from_config(&guard, name)))
|
|
.collect::<HashMap<_, _>>()
|
|
};
|
|
|
|
let now = SystemTime::now();
|
|
let mut provider_names: HashSet<String> = quota_map.keys().cloned().collect();
|
|
|
|
let ledger = self.usage_ledger.lock().await;
|
|
provider_names.extend(ledger.provider_keys().cloned());
|
|
|
|
provider_names
|
|
.into_iter()
|
|
.map(|provider| {
|
|
let quota = quota_map.get(&provider).cloned().unwrap_or_default();
|
|
ledger.snapshot(&provider, quota, now)
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
// Asynchronous access to the configuration (used internally).
|
|
pub async fn config_async(&self) -> tokio::sync::MutexGuard<'_, Config> {
|
|
self.config.lock().await
|
|
}
|
|
|
|
// Synchronous, blocking access to the configuration. This is kept for the TUI
|
|
// which expects `controller.config()` to return a reference without awaiting.
|
|
// Provide a blocking configuration lock that is safe to call from async
|
|
// contexts by using `tokio::task::block_in_place`. This allows the current
|
|
// thread to be blocked without violating Tokio's runtime constraints.
|
|
pub fn config(&self) -> tokio::sync::MutexGuard<'_, Config> {
|
|
tokio::task::block_in_place(|| self.config.blocking_lock())
|
|
}
|
|
|
|
// Synchronous mutable access, mirroring `config()` but allowing mutation.
|
|
pub fn config_mut(&self) -> tokio::sync::MutexGuard<'_, Config> {
|
|
tokio::task::block_in_place(|| self.config.blocking_lock())
|
|
}
|
|
|
|
pub fn config_cloned(&self) -> Arc<TokioMutex<Config>> {
|
|
self.config.clone()
|
|
}
|
|
|
|
pub async fn compress_now(&mut self) -> Result<Option<CompressionReport>> {
|
|
let settings = {
|
|
let guard = self.config.lock().await;
|
|
guard.chat.clone()
|
|
};
|
|
let options = CompressionOptions::from_settings(&settings);
|
|
self.perform_compression(options, false).await
|
|
}
|
|
|
|
pub async fn maybe_auto_compress(&mut self) -> Result<Option<CompressionReport>> {
|
|
let settings = {
|
|
let guard = self.config.lock().await;
|
|
if !guard.chat.auto_compress {
|
|
return Ok(None);
|
|
}
|
|
guard.chat.clone()
|
|
};
|
|
let options = CompressionOptions::from_settings(&settings);
|
|
self.perform_compression(options, true).await
|
|
}
|
|
|
|
async fn perform_compression(
|
|
&mut self,
|
|
options: CompressionOptions,
|
|
automated: bool,
|
|
) -> Result<Option<CompressionReport>> {
|
|
let mut final_report = None;
|
|
let mut iterations = 0usize;
|
|
|
|
loop {
|
|
iterations += 1;
|
|
if iterations > 4 {
|
|
break;
|
|
}
|
|
|
|
let snapshot = self.conversation.active().clone();
|
|
let total_tokens = estimate_tokens(&snapshot.messages);
|
|
if total_tokens <= options.trigger_tokens {
|
|
break;
|
|
}
|
|
|
|
if snapshot.messages.len() <= options.retain_recent + 1 {
|
|
break;
|
|
}
|
|
|
|
let split_index = snapshot
|
|
.messages
|
|
.len()
|
|
.saturating_sub(options.retain_recent);
|
|
if split_index == 0 {
|
|
break;
|
|
}
|
|
|
|
let older_messages = &snapshot.messages[..split_index];
|
|
if older_messages.len() < options.min_chunk_messages() {
|
|
break;
|
|
}
|
|
|
|
if older_messages
|
|
.iter()
|
|
.all(|msg| msg.metadata.contains_key(COMPRESSION_METADATA_KEY))
|
|
{
|
|
break;
|
|
}
|
|
|
|
let model_used = options.resolve_model(&snapshot.model);
|
|
let summary = self
|
|
.generate_summary(older_messages, &options, &model_used)
|
|
.await;
|
|
|
|
let summary_body = summary.trim();
|
|
let intro = "### Conversation summary";
|
|
let footer = if automated {
|
|
"_This summary was generated automatically to preserve context._"
|
|
} else {
|
|
"_Manual compression complete._"
|
|
};
|
|
let content = if summary_body.is_empty() {
|
|
format!(
|
|
"{intro}\n\n_Compressed {} prior messages._\n\n{footer}",
|
|
older_messages.len()
|
|
)
|
|
} else {
|
|
format!(
|
|
"{intro}\n\n{summary_body}\n\n_Compressed {} prior messages._\n\n{footer}",
|
|
older_messages.len()
|
|
)
|
|
};
|
|
|
|
let mut summary_message = Message::system(content);
|
|
let compressed_ids: Vec<String> = older_messages
|
|
.iter()
|
|
.map(|msg| msg.id.to_string())
|
|
.collect();
|
|
let summary_tokens = estimate_message_tokens(&summary_message);
|
|
let retained_tokens = estimate_tokens(&snapshot.messages[split_index..]);
|
|
let updated_tokens = summary_tokens.saturating_add(retained_tokens);
|
|
let timestamp = Utc::now();
|
|
let metadata = json!({
|
|
"strategy": match options.strategy {
|
|
CompressionStrategy::Provider => "provider",
|
|
CompressionStrategy::Local => "local",
|
|
},
|
|
"automated": automated,
|
|
"compressed_message_ids": compressed_ids,
|
|
"compressed_count": older_messages.len(),
|
|
"retain_recent": options.retain_recent,
|
|
"trigger_tokens": options.trigger_tokens,
|
|
"estimated_tokens_before": total_tokens,
|
|
"model": model_used,
|
|
"estimated_tokens_after": updated_tokens,
|
|
"timestamp": timestamp.to_rfc3339(),
|
|
});
|
|
summary_message
|
|
.metadata
|
|
.insert(COMPRESSION_METADATA_KEY.to_string(), metadata);
|
|
|
|
let mut new_messages =
|
|
Vec::with_capacity(snapshot.messages.len() - older_messages.len() + 1);
|
|
new_messages.push(summary_message.clone());
|
|
new_messages.extend_from_slice(&snapshot.messages[split_index..]);
|
|
self.conversation.replace_active_messages(new_messages);
|
|
let report = CompressionReport {
|
|
summary_message_id: summary_message.id,
|
|
compressed_messages: older_messages.len(),
|
|
estimated_tokens_before: total_tokens,
|
|
estimated_tokens_after: updated_tokens,
|
|
strategy: options.strategy,
|
|
model_used: model_used.clone(),
|
|
retained_recent: options.retain_recent,
|
|
automated,
|
|
timestamp,
|
|
};
|
|
|
|
self.last_compression = Some(report.clone());
|
|
if automated {
|
|
info!(
|
|
"auto compression reduced transcript from {} to {} tokens (compressed {} messages)",
|
|
total_tokens, updated_tokens, report.compressed_messages
|
|
);
|
|
}
|
|
self.emit_compression_event(report.clone());
|
|
final_report = Some(report.clone());
|
|
|
|
if updated_tokens >= total_tokens {
|
|
break;
|
|
}
|
|
if updated_tokens <= options.trigger_tokens {
|
|
break;
|
|
}
|
|
|
|
// Continue loop to attempt further reduction if needed.
|
|
}
|
|
|
|
Ok(final_report)
|
|
}
|
|
|
|
async fn generate_summary(
|
|
&self,
|
|
slice: &[Message],
|
|
options: &CompressionOptions,
|
|
model: &str,
|
|
) -> String {
|
|
match options.strategy {
|
|
CompressionStrategy::Provider => {
|
|
match self.generate_provider_summary(slice, model).await {
|
|
Ok(content) if !content.trim().is_empty() => content,
|
|
Ok(_) => local_summary(slice),
|
|
Err(err) => {
|
|
warn!(
|
|
"Falling back to local compression: provider summary failed ({})",
|
|
err
|
|
);
|
|
local_summary(slice)
|
|
}
|
|
}
|
|
}
|
|
CompressionStrategy::Local => local_summary(slice),
|
|
}
|
|
}
|
|
|
|
async fn generate_provider_summary(&self, slice: &[Message], model: &str) -> Result<String> {
|
|
let mut prompt_messages = Vec::new();
|
|
prompt_messages.push(Message::system("You are Owlen's transcript compactor. Summarize the provided conversation excerpt into concise markdown with sections for context, decisions, outstanding tasks, and facts that must be preserved. Avoid referring to removed content explicitly.".to_string()));
|
|
let transcript = build_transcript(slice);
|
|
prompt_messages.push(Message::user(transcript));
|
|
|
|
let request = ChatRequest {
|
|
model: model.to_string(),
|
|
messages: prompt_messages,
|
|
parameters: ChatParameters::default(),
|
|
tools: None,
|
|
};
|
|
|
|
let response = self.provider.send_prompt(request).await?;
|
|
Ok(response.message.content)
|
|
}
|
|
|
|
fn emit_compression_event(&self, report: CompressionReport) {
|
|
if let Some(tx) = &self.event_tx {
|
|
let _ = tx.send(ControllerEvent::CompressionCompleted { report });
|
|
}
|
|
}
|
|
|
|
pub async fn reload_mcp_clients(&mut self) -> Result<()> {
|
|
let (primary, named, missing) = Self::create_mcp_clients(
|
|
self.config.clone(),
|
|
self.tool_registry.clone(),
|
|
self.schema_validator.clone(),
|
|
self.credential_manager.clone(),
|
|
self.current_mode,
|
|
)
|
|
.await?;
|
|
self.mcp_client = primary;
|
|
self.named_mcp_clients = named;
|
|
self.missing_oauth_servers = missing;
|
|
Ok(())
|
|
}
|
|
|
|
pub fn grant_consent(&self, tool_name: &str, data_types: Vec<String>, endpoints: Vec<String>) {
|
|
let mut consent = self
|
|
.consent_manager
|
|
.lock()
|
|
.expect("Consent manager mutex poisoned");
|
|
consent.grant_consent(tool_name, data_types, endpoints);
|
|
|
|
let Some(vault) = &self.vault else {
|
|
return;
|
|
};
|
|
if let Err(e) = consent.persist_to_vault(vault) {
|
|
eprintln!("Warning: Failed to persist consent to vault: {}", e);
|
|
}
|
|
}
|
|
|
|
pub fn grant_consent_with_scope(
|
|
&self,
|
|
tool_name: &str,
|
|
data_types: Vec<String>,
|
|
endpoints: Vec<String>,
|
|
scope: crate::consent::ConsentScope,
|
|
) {
|
|
let mut consent = self
|
|
.consent_manager
|
|
.lock()
|
|
.expect("Consent manager mutex poisoned");
|
|
let is_permanent = matches!(scope, crate::consent::ConsentScope::Permanent);
|
|
consent.grant_consent_with_scope(tool_name, data_types, endpoints, scope);
|
|
|
|
// Only persist to vault for permanent consent
|
|
if !is_permanent {
|
|
return;
|
|
}
|
|
let Some(vault) = &self.vault else {
|
|
return;
|
|
};
|
|
if let Err(e) = consent.persist_to_vault(vault) {
|
|
eprintln!("Warning: Failed to persist consent to vault: {}", e);
|
|
}
|
|
}
|
|
|
|
pub fn check_tools_consent_needed(
|
|
&self,
|
|
tool_calls: &[ToolCall],
|
|
) -> Vec<(String, Vec<String>, Vec<String>)> {
|
|
let consent = self
|
|
.consent_manager
|
|
.lock()
|
|
.expect("Consent manager mutex poisoned");
|
|
let mut needs_consent = Vec::new();
|
|
let mut seen_tools = std::collections::HashSet::new();
|
|
|
|
for tool_call in tool_calls {
|
|
let canonical = canonical_tool_name(tool_call.name.as_str()).to_string();
|
|
if seen_tools.contains(&canonical) {
|
|
continue;
|
|
}
|
|
seen_tools.insert(canonical.clone());
|
|
|
|
let (data_types, endpoints) = match canonical.as_str() {
|
|
WEB_SEARCH_TOOL_NAME => (
|
|
vec!["search query".to_string()],
|
|
vec!["cloud provider".to_string()],
|
|
),
|
|
"code_exec" => (
|
|
vec!["code to execute".to_string()],
|
|
vec!["local sandbox".to_string()],
|
|
),
|
|
"resources_write" | "file_write" => (
|
|
vec!["file paths".to_string(), "file content".to_string()],
|
|
vec!["local filesystem".to_string()],
|
|
),
|
|
"resources_delete" | "file_delete" => (
|
|
vec!["file paths".to_string()],
|
|
vec!["local filesystem".to_string()],
|
|
),
|
|
_ => (vec![], vec![]),
|
|
};
|
|
|
|
if let Some((tool_name, dt, ep)) =
|
|
consent.check_if_consent_needed(&tool_call.name, data_types, endpoints)
|
|
{
|
|
needs_consent.push((tool_name, dt, ep));
|
|
}
|
|
}
|
|
|
|
needs_consent
|
|
}
|
|
|
|
pub async fn save_active_session(
|
|
&self,
|
|
name: Option<String>,
|
|
description: Option<String>,
|
|
) -> Result<Uuid> {
|
|
self.conversation
|
|
.save_active_with_description(&self.storage, name, description)
|
|
.await
|
|
}
|
|
|
|
pub async fn save_active_session_simple(&self, name: Option<String>) -> Result<Uuid> {
|
|
self.conversation.save_active(&self.storage, name).await
|
|
}
|
|
|
|
pub async fn load_saved_session(&mut self, id: Uuid) -> Result<()> {
|
|
self.conversation.load_saved(&self.storage, id).await
|
|
}
|
|
|
|
pub async fn list_saved_sessions(&self) -> Result<Vec<SessionMeta>> {
|
|
ConversationManager::list_saved_sessions(&self.storage).await
|
|
}
|
|
|
|
pub async fn delete_session(&self, id: Uuid) -> Result<()> {
|
|
self.storage.delete_session(id).await
|
|
}
|
|
|
|
pub async fn clear_secure_data(&self) -> Result<()> {
|
|
// ... (implementation remains the same)
|
|
Ok(())
|
|
}
|
|
|
|
pub fn persist_consent(&self) -> Result<()> {
|
|
// ... (implementation remains the same)
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn set_tool_enabled(&mut self, tool: &str, enabled: bool) -> Result<()> {
|
|
{
|
|
let mut config = self.config.lock().await;
|
|
let canonical = canonical_tool_name(tool);
|
|
match canonical {
|
|
WEB_SEARCH_TOOL_NAME => {
|
|
config.tools.web_search.enabled = enabled;
|
|
config.privacy.enable_remote_search = enabled;
|
|
}
|
|
"code_exec" => config.tools.code_exec.enabled = enabled,
|
|
_ => return Err(Error::InvalidInput(format!("Unknown tool: {tool}"))),
|
|
}
|
|
}
|
|
self.rebuild_tools().await
|
|
}
|
|
|
|
pub fn consent_manager(&self) -> Arc<Mutex<ConsentManager>> {
|
|
self.consent_manager.clone()
|
|
}
|
|
|
|
pub fn tool_registry(&self) -> Arc<ToolRegistry> {
|
|
self.tool_registry.clone()
|
|
}
|
|
|
|
pub fn schema_validator(&self) -> Arc<SchemaValidator> {
|
|
self.schema_validator.clone()
|
|
}
|
|
|
|
pub fn credential_manager(&self) -> Option<Arc<CredentialManager>> {
|
|
self.credential_manager.clone()
|
|
}
|
|
|
|
pub fn pending_oauth_servers(&self) -> Vec<String> {
|
|
self.missing_oauth_servers.clone()
|
|
}
|
|
|
|
pub async fn start_oauth_device_flow(&self, server: &str) -> Result<DeviceAuthorization> {
|
|
let oauth_config = {
|
|
let config = self.config.lock().await;
|
|
let server_cfg = config
|
|
.effective_mcp_servers()
|
|
.iter()
|
|
.find(|entry| entry.name == server)
|
|
.ok_or_else(|| {
|
|
Error::Config(format!("No MCP server named '{server}' is configured"))
|
|
})?;
|
|
server_cfg.oauth.clone().ok_or_else(|| {
|
|
Error::Config(format!(
|
|
"MCP server '{server}' does not define an OAuth configuration"
|
|
))
|
|
})?
|
|
};
|
|
|
|
let client = OAuthClient::new(oauth_config)?;
|
|
client.start_device_authorization().await
|
|
}
|
|
|
|
pub async fn poll_oauth_device_flow(
|
|
&mut self,
|
|
server: &str,
|
|
authorization: &DeviceAuthorization,
|
|
) -> Result<DevicePollState> {
|
|
let oauth_config = {
|
|
let config = self.config.lock().await;
|
|
let server_cfg = config
|
|
.effective_mcp_servers()
|
|
.iter()
|
|
.find(|entry| entry.name == server)
|
|
.ok_or_else(|| {
|
|
Error::Config(format!("No MCP server named '{server}' is configured"))
|
|
})?;
|
|
server_cfg.oauth.clone().ok_or_else(|| {
|
|
Error::Config(format!(
|
|
"MCP server '{server}' does not define an OAuth configuration"
|
|
))
|
|
})?
|
|
};
|
|
|
|
let client = OAuthClient::new(oauth_config)?;
|
|
match client.poll_device_token(authorization).await? {
|
|
DevicePollState::Pending { retry_in } => Ok(DevicePollState::Pending { retry_in }),
|
|
DevicePollState::Complete(token) => {
|
|
let manager = self.credential_manager.as_ref().cloned().ok_or_else(|| {
|
|
Error::Config(
|
|
"OAuth token storage requires encrypted local data; set \
|
|
privacy.encrypt_local_data = true in the configuration."
|
|
.to_string(),
|
|
)
|
|
})?;
|
|
|
|
manager.store_oauth_token(server, &token).await?;
|
|
self.missing_oauth_servers.retain(|entry| entry != server);
|
|
|
|
Ok(DevicePollState::Complete(token))
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn list_mcp_tools(&self) -> Vec<(String, crate::mcp::McpToolDescriptor)> {
|
|
let mut entries = Vec::new();
|
|
for (server, client) in self.named_mcp_clients.iter() {
|
|
let server_name = server.clone();
|
|
let client = Arc::clone(client);
|
|
match client.list_tools().await {
|
|
Ok(tools) => {
|
|
for descriptor in tools {
|
|
entries.push((server_name.clone(), descriptor));
|
|
}
|
|
}
|
|
Err(err) => {
|
|
warn!(
|
|
"Failed to list tools for MCP server '{}': {}",
|
|
server_name, err
|
|
);
|
|
}
|
|
}
|
|
}
|
|
entries
|
|
}
|
|
|
|
pub async fn call_mcp_tool(
|
|
&self,
|
|
server: &str,
|
|
tool: &str,
|
|
arguments: Value,
|
|
) -> Result<crate::mcp::McpToolResponse> {
|
|
let client = self.named_mcp_clients.get(server).cloned().ok_or_else(|| {
|
|
Error::Config(format!("No MCP server named '{}' is registered", server))
|
|
})?;
|
|
client
|
|
.call_tool(McpToolCall {
|
|
name: tool.to_string(),
|
|
arguments,
|
|
})
|
|
.await
|
|
}
|
|
|
|
pub fn mcp_server(&self) -> crate::mcp::McpServer {
|
|
crate::mcp::McpServer::new(self.tool_registry(), self.schema_validator())
|
|
}
|
|
|
|
pub fn storage(&self) -> Arc<StorageManager> {
|
|
self.storage.clone()
|
|
}
|
|
|
|
pub fn master_key(&self) -> Option<Arc<Vec<u8>>> {
|
|
self.master_key.as_ref().map(Arc::clone)
|
|
}
|
|
|
|
pub fn vault(&self) -> Option<Arc<Mutex<VaultHandle>>> {
|
|
self.vault.as_ref().map(Arc::clone)
|
|
}
|
|
|
|
pub async fn read_file(&self, path: &str) -> Result<String> {
|
|
let call = McpToolCall {
|
|
name: "resources_get".to_string(),
|
|
arguments: serde_json::json!({ "path": path }),
|
|
};
|
|
match self.mcp_client.call_tool(call).await {
|
|
Ok(response) => {
|
|
if let Some(text) = extract_resource_content(&response.output) {
|
|
return Ok(text);
|
|
}
|
|
|
|
let formatted = serde_json::to_string_pretty(&response.output)
|
|
.unwrap_or_else(|_| response.output.to_string());
|
|
Ok(formatted)
|
|
}
|
|
Err(err) => {
|
|
log::warn!("MCP file read failed ({}); falling back to local read", err);
|
|
let content = std::fs::read_to_string(path)?;
|
|
Ok(content)
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn read_file_with_tools(&self, path: &str) -> Result<String> {
|
|
if !self.enable_code_tools {
|
|
return Err(Error::InvalidInput(
|
|
"Code tools are disabled in chat mode. Run `:mode code` to switch.".to_string(),
|
|
));
|
|
}
|
|
|
|
let call = McpToolCall {
|
|
name: "resources_get".to_string(),
|
|
arguments: serde_json::json!({ "path": path }),
|
|
};
|
|
|
|
let response = self.mcp_client.call_tool(call).await?;
|
|
if let Some(text) = extract_resource_content(&response.output) {
|
|
Ok(text)
|
|
} else {
|
|
let formatted = serde_json::to_string_pretty(&response.output)
|
|
.unwrap_or_else(|_| response.output.to_string());
|
|
Ok(formatted)
|
|
}
|
|
}
|
|
|
|
pub fn code_tools_enabled(&self) -> bool {
|
|
self.enable_code_tools
|
|
}
|
|
|
|
pub async fn set_code_tools_enabled(&mut self, enabled: bool) -> Result<()> {
|
|
if self.enable_code_tools == enabled {
|
|
return Ok(());
|
|
}
|
|
|
|
self.enable_code_tools = enabled;
|
|
self.rebuild_tools().await
|
|
}
|
|
|
|
pub async fn set_operating_mode(&mut self, mode: Mode) -> Result<()> {
|
|
self.current_mode = mode;
|
|
let enable_code_tools = matches!(mode, Mode::Code);
|
|
self.set_code_tools_enabled(enable_code_tools).await?;
|
|
self.mcp_client.set_mode(mode).await
|
|
}
|
|
|
|
pub async fn list_dir(&self, path: &str) -> Result<Vec<String>> {
|
|
let call = McpToolCall {
|
|
name: "resources_list".to_string(),
|
|
arguments: serde_json::json!({ "path": path }),
|
|
};
|
|
match self.mcp_client.call_tool(call).await {
|
|
Ok(response) => {
|
|
let content: Vec<String> = serde_json::from_value(response.output)?;
|
|
Ok(content)
|
|
}
|
|
Err(err) => {
|
|
log::warn!(
|
|
"MCP directory list failed ({}); falling back to local list",
|
|
err
|
|
);
|
|
let mut entries = Vec::new();
|
|
for entry in std::fs::read_dir(path)? {
|
|
let entry = entry?;
|
|
entries.push(entry.file_name().to_string_lossy().to_string());
|
|
}
|
|
Ok(entries)
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn write_file(&self, path: &str, content: &str) -> Result<()> {
|
|
let call = McpToolCall {
|
|
name: "resources_write".to_string(),
|
|
arguments: serde_json::json!({ "path": path, "content": content }),
|
|
};
|
|
match self.mcp_client.call_tool(call).await {
|
|
Ok(_) => Ok(()),
|
|
Err(err) => {
|
|
log::warn!(
|
|
"MCP file write failed ({}); falling back to local write",
|
|
err
|
|
);
|
|
// Ensure parent directory exists
|
|
if let Some(parent) = std::path::Path::new(path).parent() {
|
|
std::fs::create_dir_all(parent)?;
|
|
}
|
|
std::fs::write(path, content)?;
|
|
Ok(())
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn delete_file(&self, path: &str) -> Result<()> {
|
|
let call = McpToolCall {
|
|
name: "resources_delete".to_string(),
|
|
arguments: serde_json::json!({ "path": path }),
|
|
};
|
|
match self.mcp_client.call_tool(call).await {
|
|
Ok(_) => Ok(()),
|
|
Err(err) => {
|
|
log::warn!(
|
|
"MCP file delete failed ({}); falling back to local delete",
|
|
err
|
|
);
|
|
std::fs::remove_file(path)?;
|
|
Ok(())
|
|
}
|
|
}
|
|
}
|
|
|
|
async fn rebuild_tools(&mut self) -> Result<()> {
|
|
let (registry, validator) = build_tools(
|
|
self.config.clone(),
|
|
self.ui.clone(),
|
|
self.enable_code_tools,
|
|
self.consent_manager.clone(),
|
|
self.credential_manager.clone(),
|
|
self.vault.clone(),
|
|
)
|
|
.await?;
|
|
self.tool_registry = registry;
|
|
self.schema_validator = validator;
|
|
|
|
// Recreate MCP client with permission layer
|
|
let config = self.config.lock().await;
|
|
let factory = McpClientFactory::new(
|
|
Arc::new(config.clone()),
|
|
self.tool_registry.clone(),
|
|
self.schema_validator.clone(),
|
|
);
|
|
let base_client = factory.create()?;
|
|
let permission_client = PermissionLayer::new(base_client, Arc::new(config.clone()));
|
|
let client = Arc::new(permission_client);
|
|
client.set_mode(self.current_mode).await?;
|
|
self.mcp_client = client;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub fn selected_model(&self) -> &str {
|
|
&self.conversation.active().model
|
|
}
|
|
|
|
pub async fn set_model(&mut self, model: String) {
|
|
self.conversation.set_model(model.clone());
|
|
let mut config = self.config.lock().await;
|
|
config.general.default_model = Some(model);
|
|
}
|
|
|
|
pub async fn models(&self, force_refresh: bool) -> Result<Vec<ModelInfo>> {
|
|
self.model_manager
|
|
.get_or_refresh(force_refresh, || async {
|
|
self.provider.list_models().await
|
|
})
|
|
.await
|
|
}
|
|
|
|
fn as_ollama(&self) -> Option<&OllamaProvider> {
|
|
self.provider
|
|
.as_ref()
|
|
.as_any()
|
|
.downcast_ref::<OllamaProvider>()
|
|
}
|
|
|
|
pub async fn model_details(
|
|
&self,
|
|
model_name: &str,
|
|
force_refresh: bool,
|
|
) -> Result<DetailedModelInfo> {
|
|
if let Some(ollama) = self.as_ollama() {
|
|
if force_refresh {
|
|
ollama.refresh_model_info(model_name).await
|
|
} else {
|
|
ollama.get_model_info(model_name).await
|
|
}
|
|
} else {
|
|
Err(Error::NotImplemented(format!(
|
|
"Provider '{}' does not expose model inspection",
|
|
self.provider.name()
|
|
)))
|
|
}
|
|
}
|
|
|
|
pub async fn all_model_details(&self, force_refresh: bool) -> Result<Vec<DetailedModelInfo>> {
|
|
if let Some(ollama) = self.as_ollama() {
|
|
if force_refresh {
|
|
ollama.clear_model_info_cache().await;
|
|
}
|
|
ollama.get_all_models_info().await
|
|
} else {
|
|
Err(Error::NotImplemented(format!(
|
|
"Provider '{}' does not expose model inspection",
|
|
self.provider.name()
|
|
)))
|
|
}
|
|
}
|
|
|
|
pub async fn cached_model_details(&self) -> Vec<DetailedModelInfo> {
|
|
if let Some(ollama) = self.as_ollama() {
|
|
ollama.cached_model_info().await
|
|
} else {
|
|
Vec::new()
|
|
}
|
|
}
|
|
|
|
pub async fn invalidate_model_details(&self, model_name: &str) {
|
|
if let Some(ollama) = self.as_ollama() {
|
|
ollama.invalidate_model_info(model_name).await;
|
|
}
|
|
}
|
|
|
|
pub async fn clear_model_details_cache(&self) {
|
|
if let Some(ollama) = self.as_ollama() {
|
|
ollama.clear_model_info_cache().await;
|
|
}
|
|
}
|
|
|
|
pub async fn ensure_default_model(&mut self, models: &[ModelInfo]) {
|
|
let mut config = self.config.lock().await;
|
|
if let Some(default) = config.general.default_model.clone() {
|
|
if models.iter().any(|m| m.id == default || m.name == default) {
|
|
self.conversation.set_model(default.clone());
|
|
config.general.default_model = Some(default);
|
|
}
|
|
} else if let Some(model) = models.first() {
|
|
self.conversation.set_model(model.id.clone());
|
|
config.general.default_model = Some(model.id.clone());
|
|
}
|
|
}
|
|
|
|
pub async fn switch_provider(&mut self, provider: Arc<dyn Provider>) -> Result<()> {
|
|
self.provider = provider;
|
|
self.model_manager.invalidate().await;
|
|
Ok(())
|
|
}
|
|
|
|
/// Expose the underlying LLM provider.
|
|
pub fn provider(&self) -> Arc<dyn Provider> {
|
|
self.provider.clone()
|
|
}
|
|
|
|
pub async fn send_message(
|
|
&mut self,
|
|
content: String,
|
|
mut parameters: ChatParameters,
|
|
) -> Result<SessionOutcome> {
|
|
let streaming = { self.config.lock().await.general.enable_streaming || parameters.stream };
|
|
parameters.stream = streaming;
|
|
self.conversation.push_user_message(content);
|
|
let _ = self.maybe_auto_compress().await?;
|
|
self.send_request_with_current_conversation(parameters)
|
|
.await
|
|
}
|
|
|
|
    pub async fn send_request_with_current_conversation(
        &mut self,
        mut parameters: ChatParameters,
    ) -> Result<SessionOutcome> {
        let streaming = { self.config.lock().await.general.enable_streaming || parameters.stream };
        parameters.stream = streaming;

        let tools = if !self.tool_registry.all().is_empty() {
            Some(
                self.tool_registry
                    .all()
                    .into_iter()
                    .map(|tool| crate::mcp::McpToolDescriptor {
                        name: tool.name().to_string(),
                        description: tool.description().to_string(),
                        input_schema: tool.schema(),
                        requires_network: tool.requires_network(),
                        requires_filesystem: tool.requires_filesystem(),
                    })
                    .collect(),
            )
        } else {
            None
        };

        let mut request = ChatRequest {
            model: self.conversation.active().model.clone(),
            messages: self.conversation.active().messages.clone(),
            parameters: parameters.clone(),
            tools: tools.clone(),
        };

        if !streaming {
            const MAX_TOOL_ITERATIONS: usize = 5;
            for _iteration in 0..MAX_TOOL_ITERATIONS {
                match self.provider.send_prompt(request.clone()).await {
                    Ok(response) => {
                        if response.message.has_tool_calls() {
                            self.conversation.push_message(response.message.clone());
                            if let Some(tool_calls) = &response.message.tool_calls {
                                for tool_call in tool_calls {
                                    let mcp_tool_call = McpToolCall {
                                        name: tool_call.name.clone(),
                                        arguments: tool_call.arguments.clone(),
                                    };
                                    let tool_result =
                                        self.mcp_client.call_tool(mcp_tool_call).await;
                                    let tool_response_content = match tool_result {
                                        Ok(result) => serde_json::to_string_pretty(&result.output)
                                            .unwrap_or_else(|_| {
                                                "Tool execution succeeded".to_string()
                                            }),
                                        Err(e) => format!("Tool execution failed: {}", e),
                                    };
                                    let tool_msg =
                                        Message::tool(tool_call.id.clone(), tool_response_content);
                                    self.conversation.push_message(tool_msg);
                                }
                            }
                            request.messages = self.conversation.active().messages.clone();
                            continue;
                        } else {
                            if let Some(usage) = response.usage.as_ref() {
                                let _ = self.record_usage_sample(usage).await;
                            }
                            self.conversation.push_message(response.message.clone());
                            let _ = self.maybe_auto_compress().await?;
                            return Ok(SessionOutcome::Complete(response));
                        }
                    }
                    Err(err) => {
                        self.conversation
                            .push_assistant_message(format!("Error: {}", err));
                        return Err(err);
                    }
                }
            }
            self.conversation
                .push_assistant_message("Maximum tool execution iterations reached".to_string());
            return Err(crate::Error::Provider(anyhow::anyhow!(
                "Maximum tool execution iterations reached"
            )));
        }

        match self.provider.stream_prompt(request).await {
            Ok(stream) => {
                let response_id = self.conversation.start_streaming_response();
                self.stream_states
                    .insert(response_id, StreamingMessageState::new());
                Ok(SessionOutcome::Streaming {
                    response_id,
                    stream,
                })
            }
            Err(err) => {
                self.conversation
                    .push_assistant_message(format!("Error starting stream: {}", err));
                Err(err)
            }
        }
    }

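    /// Set placeholder text on an in-flight streaming message.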
    pub fn mark_stream_placeholder(&mut self, message_id: Uuid, text: &str) -> Result<()> {
        self.conversation
            .set_stream_placeholder(message_id, text.to_string())
    }

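    /// Fold a streaming chunk into the conversation: text deltas either append to or replace the
    /// accumulated content, tool calls are attached to the message, and the per-message stream
    /// state is dropped on the final chunk.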
    pub fn apply_stream_chunk(&mut self, message_id: Uuid, chunk: &ChatResponse) -> Result<()> {
        let state = self.stream_states.entry(message_id).or_default();

        let diff = state.ingest(chunk);

        if let Some(text_delta) = diff.text {
            match text_delta.mode {
                TextDeltaKind::Append => {
                    self.conversation.append_stream_chunk(
                        message_id,
                        &text_delta.content,
                        chunk.is_final,
                    )?;
                }
                TextDeltaKind::Replace => {
                    self.conversation.set_stream_content(
                        message_id,
                        text_delta.content,
                        chunk.is_final,
                    )?;
                }
            }
        } else if chunk.is_final {
            self.conversation
                .append_stream_chunk(message_id, "", true)?;
        }

        if let Some(tool_calls) = diff.tool_calls {
            self.conversation
                .set_tool_calls_on_message(message_id, tool_calls)?;
        }

        if chunk.is_final {
            state.mark_finished();
            self.stream_states.remove(&message_id);
        }

        Ok(())
    }

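    /// Return any tool calls attached to the given streamed message, queueing a consent request
    /// (and emitting a `ControllerEvent::ToolRequested`) when one is still needed.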
    pub fn check_streaming_tool_calls(&mut self, message_id: Uuid) -> Option<Vec<ToolCall>> {
        let maybe_calls = self
            .conversation
            .active()
            .messages
            .iter()
            .find(|m| m.id == message_id)
            .and_then(|m| m.tool_calls.clone())
            .filter(|calls| !calls.is_empty());

        let calls = maybe_calls?;

        if !self
            .pending_tool_requests
            .values()
            .any(|pending| pending.message_id == message_id)
        {
            if let Some((tool_name, data_types, endpoints)) =
                self.check_tools_consent_needed(&calls).into_iter().next()
            {
                let request_id = Uuid::new_v4();
                let pending = PendingToolRequest {
                    message_id,
                    tool_name: tool_name.clone(),
                    data_types: data_types.clone(),
                    endpoints: endpoints.clone(),
                    tool_calls: calls.clone(),
                };
                self.pending_tool_requests.insert(request_id, pending);

                if let Some(tx) = &self.event_tx {
                    let _ = tx.send(ControllerEvent::ToolRequested {
                        request_id,
                        message_id,
                        tool_name,
                        data_types,
                        endpoints,
                        tool_calls: calls.clone(),
                    });
                }
            }
        }

        Some(calls)
    }

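    /// Resolve a pending tool-consent request: record the granted scope unless it was denied,
    /// and return the tool calls that were waiting on the decision.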
    pub fn resolve_tool_consent(
        &mut self,
        request_id: Uuid,
        scope: ConsentScope,
    ) -> Result<ToolConsentResolution> {
        let pending = self
            .pending_tool_requests
            .remove(&request_id)
            .ok_or_else(|| {
                Error::InvalidInput(format!("Unknown tool consent request: {}", request_id))
            })?;

        let PendingToolRequest {
            message_id,
            tool_name,
            data_types,
            endpoints,
            tool_calls,
            ..
        } = pending;

        if !matches!(scope, ConsentScope::Denied) {
            self.grant_consent_with_scope(&tool_name, data_types, endpoints, scope.clone());
        }

        Ok(ToolConsentResolution {
            request_id,
            message_id,
            tool_name,
            scope,
            tool_calls,
        })
    }

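    /// Drop the streaming state for a message and mark it as cancelled with the given notice.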
    pub fn cancel_stream(&mut self, message_id: Uuid, notice: &str) -> Result<()> {
        self.stream_states.remove(&message_id);
        self.conversation
            .cancel_stream(message_id, notice.to_string())
    }

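    /// Run the tool calls surfaced during streaming, append each result as a tool message, and
    /// re-send the conversation to the provider.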
    pub async fn execute_streaming_tools(
        &mut self,
        _message_id: Uuid,
        tool_calls: Vec<ToolCall>,
    ) -> Result<SessionOutcome> {
        for tool_call in &tool_calls {
            let mcp_tool_call = McpToolCall {
                name: tool_call.name.clone(),
                arguments: tool_call.arguments.clone(),
            };
            let tool_result = self.mcp_client.call_tool(mcp_tool_call).await;
            let tool_response_content = match tool_result {
                Ok(result) => serde_json::to_string_pretty(&result.output)
                    .unwrap_or_else(|_| "Tool execution succeeded".to_string()),
                Err(e) => format!("Tool execution failed: {}", e),
            };
            let tool_msg = Message::tool(tool_call.id.clone(), tool_response_content);
            self.conversation.push_message(tool_msg);
        }
        let parameters = ChatParameters {
            stream: self.config.lock().await.general.enable_streaming,
            ..Default::default()
        };
        self.send_request_with_current_conversation(parameters)
            .await
    }

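    /// Clone of every conversation tracked by this session.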
    pub fn history(&self) -> Vec<Conversation> {
        self.conversation.history().cloned().collect()
    }

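    /// Begin a new conversation and discard any in-flight streaming state.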
    pub fn start_new_conversation(&mut self, model: Option<String>, name: Option<String>) {
        self.conversation.start_new(model, name);
        self.stream_states.clear();
    }

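    /// Clear the current conversation along with any in-flight streaming state.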
    pub fn clear(&mut self) {
        self.conversation.clear();
        self.stream_states.clear();
    }

    pub async fn generate_conversation_description(&self) -> Result<String> {
        // ... (implementation remains the same)
        Ok("Empty conversation".to_string())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::Provider;
    use crate::config::{Config, McpMode, McpOAuthConfig, McpServerConfig};
    use crate::llm::test_utils::MockProvider;
    use crate::storage::StorageManager;
    use crate::ui::NoOpUiController;
    use chrono::Utc;
    use httpmock::prelude::*;
    use serde_json::json;
    use std::collections::HashMap;
    use std::sync::Arc;
    use tempfile::tempdir;

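    // Helper: build a synthetic streaming ChatResponse with optional tool calls.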
    fn make_response(
        text: &str,
        tool_calls: Option<Vec<ToolCall>>,
        is_final: bool,
    ) -> ChatResponse {
        let mut message = Message::assistant(text.to_string());
        message.tool_calls = tool_calls;
        ChatResponse {
            message,
            usage: None,
            is_streaming: true,
            is_final,
        }
    }

    fn make_tool_call(id: &str, name: &str) -> ToolCall {
        ToolCall {
            id: id.to_string(),
            name: name.to_string(),
            arguments: serde_json::json!({}),
        }
    }

    const SERVER_NAME: &str = "oauth-test";

    fn build_oauth_config(server: &MockServer) -> McpOAuthConfig {
        McpOAuthConfig {
            client_id: "owlen-client".to_string(),
            client_secret: None,
            authorize_url: server.url("/authorize"),
            token_url: server.url("/token"),
            device_authorization_url: Some(server.url("/device")),
            redirect_url: None,
            scopes: vec!["repo".to_string()],
            token_env: Some("OAUTH_TOKEN".to_string()),
            header: Some("Authorization".to_string()),
            header_prefix: Some("Bearer ".to_string()),
        }
    }

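    // Config with a single OAuth-enabled MCP server pointed at the mock HTTP server.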
    fn build_config(server: &MockServer) -> Config {
        let mut config = Config::default();
        config.mcp.mode = McpMode::LocalOnly;
        let oauth = build_oauth_config(server);

        let mut env = HashMap::new();
        env.insert("OWLEN_ENV".to_string(), "test".to_string());

        config.mcp_servers = vec![McpServerConfig {
            name: SERVER_NAME.to_string(),
            command: server.url("/mcp"),
            args: Vec::new(),
            transport: "http".to_string(),
            env,
            oauth: Some(oauth),
        }];

        config.refresh_mcp_servers(None).unwrap();
        config
    }

    #[test]
    fn streaming_state_tracks_text_deltas() {
        let mut state = StreamingMessageState::new();

        let diff = state.ingest(&make_response("Hello", None, false));
        let first = diff.text.expect("text diff");
        assert_eq!(first.content, "Hello");
        assert_eq!(first.mode, TextDeltaKind::Append);

        let diff = state.ingest(&make_response("Hello world", None, false));
        let second = diff.text.expect("second diff");
        assert_eq!(second.content, " world");
        assert_eq!(second.mode, TextDeltaKind::Append);

        let diff = state.ingest(&make_response("Hi", None, false));
        let third = diff.text.expect("third diff");
        assert_eq!(third.content, "Hi");
        assert_eq!(third.mode, TextDeltaKind::Replace);
    }

    #[test]
    fn streaming_state_detects_tool_call_changes() {
        let mut state = StreamingMessageState::new();
        let tool = make_tool_call("call-1", "web_search");

        let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false));
        let calls = diff.tool_calls.expect("initial tool call");
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "web_search");

        let diff = state.ingest(&make_response("", Some(vec![tool.clone()]), false));
        assert!(
            diff.tool_calls.is_none(),
            "duplicate tool call should not emit"
        );

        let diff = state.ingest(&make_response("", Some(vec![]), false));
        let cleared = diff.tool_calls.expect("clearing tool calls");
        assert!(cleared.is_empty());
    }

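    // Construct a SessionController backed by temporary storage, a mock provider, and the
    // OAuth test config above.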
    async fn build_session(server: &MockServer) -> (SessionController, tempfile::TempDir) {
        unsafe {
            std::env::set_var("OWLEN_MASTER_PASSWORD", "test-password");
        }

        let temp_dir = tempdir().expect("tempdir");
        let storage_path = temp_dir.path().join("owlen.db");
        let storage = Arc::new(
            StorageManager::with_database_path(storage_path)
                .await
                .expect("storage"),
        );

        let config = build_config(server);
        let provider: Arc<dyn Provider> = Arc::new(MockProvider::default()) as Arc<dyn Provider>;
        let ui = Arc::new(NoOpUiController);

        let session = SessionController::new(provider, config, storage, ui, false, None)
            .await
            .expect("session");

        (session, temp_dir)
    }

    #[tokio::test]
    async fn start_oauth_device_flow_returns_details() {
        let server = MockServer::start_async().await;
        let device = server
            .mock_async(|when, then| {
                when.method(POST).path("/device");
                then.status(200)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "device_code": "device-abc",
                        "user_code": "ABCD-1234",
                        "verification_uri": "https://example.test/activate",
                        "verification_uri_complete": "https://example.test/activate?user_code=ABCD-1234",
                        "expires_in": 600,
                        "interval": 5,
                        "message": "Enter the code to continue."
                    }));
            })
            .await;

        let (session, _dir) = build_session(&server).await;
        let authorization = session
            .start_oauth_device_flow(SERVER_NAME)
            .await
            .expect("device flow");

        assert_eq!(authorization.user_code, "ABCD-1234");
        assert_eq!(
            authorization.verification_uri_complete.as_deref(),
            Some("https://example.test/activate?user_code=ABCD-1234")
        );
        assert!(authorization.expires_at > Utc::now());
        device.assert_async().await;
    }

    #[tokio::test]
    async fn poll_oauth_device_flow_stores_token_and_updates_state() {
        let server = MockServer::start_async().await;

        let device = server
            .mock_async(|when, then| {
                when.method(POST).path("/device");
                then.status(200)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "device_code": "device-xyz",
                        "user_code": "WXYZ-9999",
                        "verification_uri": "https://example.test/activate",
                        "verification_uri_complete": "https://example.test/activate?user_code=WXYZ-9999",
                        "expires_in": 600,
                        "interval": 5
                    }));
            })
            .await;

        let token = server
            .mock_async(|when, then| {
                when.method(POST)
                    .path("/token")
                    .body_contains("device_code=device-xyz");
                then.status(200)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "access_token": "new-access-token",
                        "refresh_token": "refresh-token",
                        "expires_in": 3600,
                        "token_type": "Bearer"
                    }));
            })
            .await;

        let (mut session, _dir) = build_session(&server).await;
        assert_eq!(session.pending_oauth_servers(), vec![SERVER_NAME]);

        let authorization = session
            .start_oauth_device_flow(SERVER_NAME)
            .await
            .expect("device flow");

        match session
            .poll_oauth_device_flow(SERVER_NAME, &authorization)
            .await
            .expect("token poll")
        {
            DevicePollState::Complete(token_info) => {
                assert_eq!(token_info.access_token, "new-access-token");
                assert_eq!(token_info.refresh_token.as_deref(), Some("refresh-token"));
            }
            other => panic!("expected token completion, got {other:?}"),
        }

        assert!(
            session
                .pending_oauth_servers()
                .iter()
                .all(|entry| entry != SERVER_NAME),
            "server should be removed from pending list"
        );

        let stored = session
            .credential_manager()
            .expect("credential manager")
            .load_oauth_token(SERVER_NAME)
            .await
            .expect("load token")
            .expect("token present");

        assert_eq!(stored.access_token, "new-access-token");
        assert_eq!(stored.refresh_token.as_deref(), Some("refresh-token"));

        device.assert_async().await;
        token.assert_async().await;
    }
}