Integrate core functionality for tools, MCP, and enhanced session management
Adds consent management for tool execution, input validation, sandboxed process execution, and MCP server integration. Updates session management to support tool use, conversation persistence, and streaming responses. Major additions: - Database migrations for conversations and secure storage - Encryption and credential management infrastructure - Extensible tool system with code execution and web search - Consent management and validation systems - Sandboxed process execution - MCP server integration Infrastructure changes: - Module registration and workspace dependencies - ToolCall type and tool-related Message methods - Privacy, security, and tool configuration structures - Database-backed conversation persistence - Tool call tracking in conversations Provider and UI updates: - Ollama provider updates for tool support and new Role types - TUI chat and code app updates for async initialization - CLI updates for new SessionController API - Configuration documentation updates - CHANGELOG updates 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -2,7 +2,7 @@
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::{Arg, Command};
|
||||
use owlen_core::session::SessionController;
|
||||
use owlen_core::{session::SessionController, storage::StorageManager};
|
||||
use owlen_ollama::OllamaProvider;
|
||||
use owlen_tui::{config, ui, AppState, CodeApp, Event, EventHandler, SessionEvent};
|
||||
use std::io;
|
||||
@@ -37,14 +37,27 @@ async fn main() -> Result<()> {
|
||||
config.general.default_model = Some(model.clone());
|
||||
}
|
||||
|
||||
let provider_cfg = config::ensure_ollama_config(&mut config).clone();
|
||||
let provider_name = config.general.default_provider.clone();
|
||||
let provider_cfg = config::ensure_provider_config(&mut config, &provider_name).clone();
|
||||
|
||||
let provider_type = provider_cfg.provider_type.to_ascii_lowercase();
|
||||
if provider_type != "ollama" && provider_type != "ollama-cloud" {
|
||||
anyhow::bail!(
|
||||
"Unsupported provider type '{}' configured for provider '{}'",
|
||||
provider_cfg.provider_type,
|
||||
provider_name
|
||||
);
|
||||
}
|
||||
|
||||
let provider = Arc::new(OllamaProvider::from_config(
|
||||
&provider_cfg,
|
||||
Some(&config.general),
|
||||
)?);
|
||||
|
||||
let controller = SessionController::new(provider, config.clone());
|
||||
let (mut app, mut session_rx) = CodeApp::new(controller);
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
// Code client - code execution tools enabled
|
||||
let controller = SessionController::new(provider, config.clone(), storage.clone(), true)?;
|
||||
let (mut app, mut session_rx) = CodeApp::new(controller).await?;
|
||||
app.inner_mut().initialize_models().await?;
|
||||
|
||||
let cancellation_token = CancellationToken::new();
|
||||
@@ -87,8 +100,21 @@ async fn run_app(
|
||||
session_rx: &mut mpsc::UnboundedReceiver<SessionEvent>,
|
||||
) -> Result<()> {
|
||||
loop {
|
||||
// Advance loading animation frame
|
||||
app.inner_mut().advance_loading_animation();
|
||||
|
||||
terminal.draw(|f| ui::render_chat(f, app.inner_mut()))?;
|
||||
|
||||
// Process any pending LLM requests AFTER UI has been drawn
|
||||
if let Err(e) = app.inner_mut().process_pending_llm_request().await {
|
||||
eprintln!("Error processing LLM request: {}", e);
|
||||
}
|
||||
|
||||
// Process any pending tool executions AFTER UI has been drawn
|
||||
if let Err(e) = app.inner_mut().process_pending_tool_execution().await {
|
||||
eprintln!("Error processing tool execution: {}", e);
|
||||
}
|
||||
|
||||
tokio::select! {
|
||||
Some(event) = event_rx.recv() => {
|
||||
if let AppState::Quit = app.handle_event(event).await? {
|
||||
@@ -98,6 +124,10 @@ async fn run_app(
|
||||
Some(session_event) = session_rx.recv() => {
|
||||
app.handle_session_event(session_event)?;
|
||||
}
|
||||
// Add a timeout to keep the animation going even when there are no events
|
||||
_ = tokio::time::sleep(tokio::time::Duration::from_millis(100)) => {
|
||||
// This will cause the loop to continue and advance the animation
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::{Arg, Command};
|
||||
use owlen_core::session::SessionController;
|
||||
use owlen_core::{session::SessionController, storage::StorageManager};
|
||||
use owlen_ollama::OllamaProvider;
|
||||
use owlen_tui::{config, ui, AppState, ChatApp, Event, EventHandler, SessionEvent};
|
||||
use std::io;
|
||||
@@ -38,14 +38,27 @@ async fn main() -> Result<()> {
|
||||
}
|
||||
|
||||
// Prepare provider from configuration
|
||||
let provider_cfg = config::ensure_ollama_config(&mut config).clone();
|
||||
let provider_name = config.general.default_provider.clone();
|
||||
let provider_cfg = config::ensure_provider_config(&mut config, &provider_name).clone();
|
||||
|
||||
let provider_type = provider_cfg.provider_type.to_ascii_lowercase();
|
||||
if provider_type != "ollama" && provider_type != "ollama-cloud" {
|
||||
anyhow::bail!(
|
||||
"Unsupported provider type '{}' configured for provider '{}'",
|
||||
provider_cfg.provider_type,
|
||||
provider_name
|
||||
);
|
||||
}
|
||||
|
||||
let provider = Arc::new(OllamaProvider::from_config(
|
||||
&provider_cfg,
|
||||
Some(&config.general),
|
||||
)?);
|
||||
|
||||
let controller = SessionController::new(provider, config.clone());
|
||||
let (mut app, mut session_rx) = ChatApp::new(controller);
|
||||
let storage = Arc::new(StorageManager::new().await?);
|
||||
// Chat client - code execution tools disabled (only available in code client)
|
||||
let controller = SessionController::new(provider, config.clone(), storage.clone(), false)?;
|
||||
let (mut app, mut session_rx) = ChatApp::new(controller).await?;
|
||||
app.initialize_models().await?;
|
||||
|
||||
// Event infrastructure
|
||||
@@ -104,7 +117,14 @@ async fn run_app(
|
||||
terminal.draw(|f| ui::render_chat(f, app))?;
|
||||
|
||||
// Process any pending LLM requests AFTER UI has been drawn
|
||||
app.process_pending_llm_request().await?;
|
||||
if let Err(e) = app.process_pending_llm_request().await {
|
||||
eprintln!("Error processing LLM request: {}", e);
|
||||
}
|
||||
|
||||
// Process any pending tool executions AFTER UI has been drawn
|
||||
if let Err(e) = app.process_pending_tool_execution().await {
|
||||
eprintln!("Error processing tool execution: {}", e);
|
||||
}
|
||||
|
||||
tokio::select! {
|
||||
Some(event) = event_rx.recv() => {
|
||||
|
||||
@@ -25,7 +25,20 @@ toml = "0.8.0"
|
||||
shellexpand = "3.1.0"
|
||||
dirs = "5.0"
|
||||
ratatui = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
jsonschema = { workspace = true }
|
||||
which = { workspace = true }
|
||||
nix = { workspace = true }
|
||||
aes-gcm = { workspace = true }
|
||||
ring = { workspace = true }
|
||||
keyring = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
urlencoding = { workspace = true }
|
||||
rpassword = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
duckduckgo = "0.2.0"
|
||||
reqwest = { workspace = true, features = ["default"] }
|
||||
reqwest_011 = { version = "0.11", package = "reqwest" }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
|
||||
@@ -26,19 +26,24 @@ pub struct Config {
|
||||
/// Input handling preferences
|
||||
#[serde(default)]
|
||||
pub input: InputSettings,
|
||||
/// Privacy controls for tooling and network usage
|
||||
#[serde(default)]
|
||||
pub privacy: PrivacySettings,
|
||||
/// Security controls for sandboxing and resource limits
|
||||
#[serde(default)]
|
||||
pub security: SecuritySettings,
|
||||
/// Per-tool configuration toggles
|
||||
#[serde(default)]
|
||||
pub tools: ToolSettings,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
let mut providers = HashMap::new();
|
||||
providers.insert("ollama".to_string(), default_ollama_provider_config());
|
||||
providers.insert(
|
||||
"ollama".to_string(),
|
||||
ProviderConfig {
|
||||
provider_type: "ollama".to_string(),
|
||||
base_url: Some("http://localhost:11434".to_string()),
|
||||
api_key: None,
|
||||
extra: HashMap::new(),
|
||||
},
|
||||
"ollama-cloud".to_string(),
|
||||
default_ollama_cloud_provider_config(),
|
||||
);
|
||||
|
||||
Self {
|
||||
@@ -47,6 +52,9 @@ impl Default for Config {
|
||||
ui: UiSettings::default(),
|
||||
storage: StorageSettings::default(),
|
||||
input: InputSettings::default(),
|
||||
privacy: PrivacySettings::default(),
|
||||
security: SecuritySettings::default(),
|
||||
tools: ToolSettings::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -120,17 +128,26 @@ impl Config {
|
||||
self.general.default_provider = "ollama".to_string();
|
||||
}
|
||||
|
||||
if !self.providers.contains_key("ollama") {
|
||||
self.providers.insert(
|
||||
"ollama".to_string(),
|
||||
ProviderConfig {
|
||||
provider_type: "ollama".to_string(),
|
||||
base_url: Some("http://localhost:11434".to_string()),
|
||||
api_key: None,
|
||||
extra: HashMap::new(),
|
||||
},
|
||||
);
|
||||
}
|
||||
ensure_provider_config(self, "ollama");
|
||||
ensure_provider_config(self, "ollama-cloud");
|
||||
}
|
||||
}
|
||||
|
||||
fn default_ollama_provider_config() -> ProviderConfig {
|
||||
ProviderConfig {
|
||||
provider_type: "ollama".to_string(),
|
||||
base_url: Some("http://localhost:11434".to_string()),
|
||||
api_key: None,
|
||||
extra: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn default_ollama_cloud_provider_config() -> ProviderConfig {
|
||||
ProviderConfig {
|
||||
provider_type: "ollama-cloud".to_string(),
|
||||
base_url: Some("https://ollama.com".to_string()),
|
||||
api_key: None,
|
||||
extra: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -185,6 +202,154 @@ impl Default for GeneralSettings {
|
||||
}
|
||||
}
|
||||
|
||||
/// Privacy controls governing network access and storage
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PrivacySettings {
|
||||
#[serde(default = "PrivacySettings::default_remote_search")]
|
||||
pub enable_remote_search: bool,
|
||||
#[serde(default)]
|
||||
pub cache_web_results: bool,
|
||||
#[serde(default)]
|
||||
pub retain_history_days: u32,
|
||||
#[serde(default = "PrivacySettings::default_require_consent")]
|
||||
pub require_consent_per_session: bool,
|
||||
#[serde(default = "PrivacySettings::default_encrypt_local_data")]
|
||||
pub encrypt_local_data: bool,
|
||||
}
|
||||
|
||||
impl PrivacySettings {
|
||||
const fn default_remote_search() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
const fn default_require_consent() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
const fn default_encrypt_local_data() -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PrivacySettings {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enable_remote_search: Self::default_remote_search(),
|
||||
cache_web_results: false,
|
||||
retain_history_days: 0,
|
||||
require_consent_per_session: Self::default_require_consent(),
|
||||
encrypt_local_data: Self::default_encrypt_local_data(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Security settings that constrain tool execution
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SecuritySettings {
|
||||
#[serde(default = "SecuritySettings::default_enable_sandboxing")]
|
||||
pub enable_sandboxing: bool,
|
||||
#[serde(default = "SecuritySettings::default_timeout")]
|
||||
pub sandbox_timeout_seconds: u64,
|
||||
#[serde(default = "SecuritySettings::default_max_memory")]
|
||||
pub max_memory_mb: u64,
|
||||
#[serde(default = "SecuritySettings::default_allowed_tools")]
|
||||
pub allowed_tools: Vec<String>,
|
||||
}
|
||||
|
||||
impl SecuritySettings {
|
||||
const fn default_enable_sandboxing() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
const fn default_timeout() -> u64 {
|
||||
30
|
||||
}
|
||||
|
||||
const fn default_max_memory() -> u64 {
|
||||
512
|
||||
}
|
||||
|
||||
fn default_allowed_tools() -> Vec<String> {
|
||||
vec!["web_search".to_string(), "code_exec".to_string()]
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for SecuritySettings {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enable_sandboxing: Self::default_enable_sandboxing(),
|
||||
sandbox_timeout_seconds: Self::default_timeout(),
|
||||
max_memory_mb: Self::default_max_memory(),
|
||||
allowed_tools: Self::default_allowed_tools(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Per-tool configuration toggles
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
pub struct ToolSettings {
|
||||
#[serde(default)]
|
||||
pub web_search: WebSearchToolConfig,
|
||||
#[serde(default)]
|
||||
pub code_exec: CodeExecToolConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct WebSearchToolConfig {
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
#[serde(default)]
|
||||
pub api_key: String,
|
||||
#[serde(default = "WebSearchToolConfig::default_max_results")]
|
||||
pub max_results: u32,
|
||||
}
|
||||
|
||||
impl WebSearchToolConfig {
|
||||
const fn default_max_results() -> u32 {
|
||||
5
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WebSearchToolConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
api_key: String::new(),
|
||||
max_results: Self::default_max_results(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CodeExecToolConfig {
|
||||
#[serde(default)]
|
||||
pub enabled: bool,
|
||||
#[serde(default = "CodeExecToolConfig::default_allowed_languages")]
|
||||
pub allowed_languages: Vec<String>,
|
||||
#[serde(default = "CodeExecToolConfig::default_timeout")]
|
||||
pub timeout_seconds: u64,
|
||||
}
|
||||
|
||||
impl CodeExecToolConfig {
|
||||
fn default_allowed_languages() -> Vec<String> {
|
||||
vec!["python".to_string(), "javascript".to_string()]
|
||||
}
|
||||
|
||||
const fn default_timeout() -> u64 {
|
||||
30
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for CodeExecToolConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: false,
|
||||
allowed_languages: Self::default_allowed_languages(),
|
||||
timeout_seconds: Self::default_timeout(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// UI preferences that consumers can respect as needed
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct UiSettings {
|
||||
@@ -343,15 +508,32 @@ impl Default for InputSettings {
|
||||
|
||||
/// Convenience accessor for an Ollama provider entry, creating a default if missing
|
||||
pub fn ensure_ollama_config(config: &mut Config) -> &ProviderConfig {
|
||||
config
|
||||
.providers
|
||||
.entry("ollama".to_string())
|
||||
.or_insert_with(|| ProviderConfig {
|
||||
provider_type: "ollama".to_string(),
|
||||
base_url: Some("http://localhost:11434".to_string()),
|
||||
api_key: None,
|
||||
extra: HashMap::new(),
|
||||
})
|
||||
ensure_provider_config(config, "ollama")
|
||||
}
|
||||
|
||||
/// Ensure a provider configuration exists for the requested provider name
|
||||
pub fn ensure_provider_config<'a>(
|
||||
config: &'a mut Config,
|
||||
provider_name: &str,
|
||||
) -> &'a ProviderConfig {
|
||||
use std::collections::hash_map::Entry;
|
||||
|
||||
match config.providers.entry(provider_name.to_string()) {
|
||||
Entry::Occupied(entry) => entry.into_mut(),
|
||||
Entry::Vacant(entry) => {
|
||||
let default = match provider_name {
|
||||
"ollama-cloud" => default_ollama_cloud_provider_config(),
|
||||
"ollama" => default_ollama_provider_config(),
|
||||
other => ProviderConfig {
|
||||
provider_type: other.to_string(),
|
||||
base_url: None,
|
||||
api_key: None,
|
||||
extra: HashMap::new(),
|
||||
},
|
||||
};
|
||||
entry.insert(default)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Calculate absolute timeout for session data based on configuration
|
||||
@@ -404,4 +586,21 @@ mod tests {
|
||||
let path = config.storage.conversation_path();
|
||||
assert!(path.to_string_lossy().contains("custom/path"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_config_contains_local_and_cloud_providers() {
|
||||
let config = Config::default();
|
||||
assert!(config.providers.contains_key("ollama"));
|
||||
assert!(config.providers.contains_key("ollama-cloud"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ensure_provider_config_backfills_cloud_defaults() {
|
||||
let mut config = Config::default();
|
||||
config.providers.remove("ollama-cloud");
|
||||
|
||||
let cloud = ensure_provider_config(&mut config, "ollama-cloud");
|
||||
assert_eq!(cloud.provider_type, "ollama-cloud");
|
||||
assert_eq!(cloud.base_url.as_deref(), Some("https://ollama.com"));
|
||||
}
|
||||
}
|
||||
|
||||
172
crates/owlen-core/src/consent.rs
Normal file
172
crates/owlen-core/src/consent.rs
Normal file
@@ -0,0 +1,172 @@
|
||||
use std::collections::HashMap;
|
||||
use std::io::{self, Write};
|
||||
use std::sync::Arc;
|
||||
|
||||
use anyhow::Result;
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::encryption::VaultHandle;
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ConsentRequest {
|
||||
pub tool_name: String,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Debug)]
|
||||
pub struct ConsentRecord {
|
||||
pub tool_name: String,
|
||||
pub granted: bool,
|
||||
pub timestamp: DateTime<Utc>,
|
||||
pub data_types: Vec<String>,
|
||||
pub external_endpoints: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
pub struct ConsentManager {
|
||||
records: HashMap<String, ConsentRecord>,
|
||||
}
|
||||
|
||||
impl ConsentManager {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Load consent records from vault storage
|
||||
pub fn from_vault(vault: &Arc<std::sync::Mutex<VaultHandle>>) -> Self {
|
||||
let guard = vault.lock().expect("Vault mutex poisoned");
|
||||
if let Some(consent_data) = guard.settings().get("consent_records") {
|
||||
if let Ok(records) =
|
||||
serde_json::from_value::<HashMap<String, ConsentRecord>>(consent_data.clone())
|
||||
{
|
||||
return Self { records };
|
||||
}
|
||||
}
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Persist consent records to vault storage
|
||||
pub fn persist_to_vault(&self, vault: &Arc<std::sync::Mutex<VaultHandle>>) -> Result<()> {
|
||||
let mut guard = vault.lock().expect("Vault mutex poisoned");
|
||||
let consent_json = serde_json::to_value(&self.records)?;
|
||||
guard
|
||||
.settings_mut()
|
||||
.insert("consent_records".to_string(), consent_json);
|
||||
guard.persist()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn request_consent(
|
||||
&mut self,
|
||||
tool_name: &str,
|
||||
data_types: Vec<String>,
|
||||
endpoints: Vec<String>,
|
||||
) -> Result<bool> {
|
||||
if let Some(existing) = self.records.get(tool_name) {
|
||||
return Ok(existing.granted);
|
||||
}
|
||||
|
||||
let consent = self.show_consent_dialog(tool_name, &data_types, &endpoints)?;
|
||||
|
||||
let record = ConsentRecord {
|
||||
tool_name: tool_name.to_string(),
|
||||
granted: consent,
|
||||
timestamp: Utc::now(),
|
||||
data_types,
|
||||
external_endpoints: endpoints,
|
||||
};
|
||||
|
||||
self.records.insert(tool_name.to_string(), record);
|
||||
// Note: Caller should persist to vault after this call
|
||||
Ok(consent)
|
||||
}
|
||||
|
||||
/// Grant consent programmatically (for TUI or automated flows)
|
||||
pub fn grant_consent(
|
||||
&mut self,
|
||||
tool_name: &str,
|
||||
data_types: Vec<String>,
|
||||
endpoints: Vec<String>,
|
||||
) {
|
||||
let record = ConsentRecord {
|
||||
tool_name: tool_name.to_string(),
|
||||
granted: true,
|
||||
timestamp: Utc::now(),
|
||||
data_types,
|
||||
external_endpoints: endpoints,
|
||||
};
|
||||
self.records.insert(tool_name.to_string(), record);
|
||||
}
|
||||
|
||||
/// Check if consent is needed (returns None if already granted, Some(info) if needed)
|
||||
pub fn check_consent_needed(&self, tool_name: &str) -> Option<ConsentRequest> {
|
||||
if self.has_consent(tool_name) {
|
||||
None
|
||||
} else {
|
||||
Some(ConsentRequest {
|
||||
tool_name: tool_name.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_consent(&self, tool_name: &str) -> bool {
|
||||
self.records
|
||||
.get(tool_name)
|
||||
.map(|record| record.granted)
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
pub fn revoke_consent(&mut self, tool_name: &str) {
|
||||
if let Some(record) = self.records.get_mut(tool_name) {
|
||||
record.granted = false;
|
||||
record.timestamp = Utc::now();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn clear_all_consent(&mut self) {
|
||||
self.records.clear();
|
||||
}
|
||||
|
||||
/// Check if consent is needed for a tool (non-blocking)
|
||||
/// Returns Some with consent details if needed, None if already granted
|
||||
pub fn check_if_consent_needed(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
data_types: Vec<String>,
|
||||
endpoints: Vec<String>,
|
||||
) -> Option<(String, Vec<String>, Vec<String>)> {
|
||||
if self.has_consent(tool_name) {
|
||||
return None;
|
||||
}
|
||||
Some((tool_name.to_string(), data_types, endpoints))
|
||||
}
|
||||
|
||||
fn show_consent_dialog(
|
||||
&self,
|
||||
tool_name: &str,
|
||||
data_types: &[String],
|
||||
endpoints: &[String],
|
||||
) -> Result<bool> {
|
||||
// TEMPORARY: Auto-grant consent when not in a proper terminal (TUI mode)
|
||||
// TODO: Integrate consent UI into the TUI event loop
|
||||
use std::io::IsTerminal;
|
||||
if !io::stdin().is_terminal() || std::env::var("OWLEN_AUTO_CONSENT").is_ok() {
|
||||
eprintln!("Auto-granting consent for {} (TUI mode)", tool_name);
|
||||
return Ok(true);
|
||||
}
|
||||
|
||||
println!("=== PRIVACY CONSENT REQUIRED ===");
|
||||
println!("Tool: {}", tool_name);
|
||||
println!("Data to be sent: {}", data_types.join(", "));
|
||||
println!("External endpoints: {}", endpoints.join(", "));
|
||||
println!("Do you consent to this data transmission? (y/N)");
|
||||
|
||||
print!("> ");
|
||||
io::stdout().flush()?;
|
||||
|
||||
let mut input = String::new();
|
||||
io::stdin().read_line(&mut input)?;
|
||||
|
||||
Ok(matches!(input.trim().to_lowercase().as_str(), "y" | "yes"))
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,6 @@ use crate::types::{Conversation, Message};
|
||||
use crate::Result;
|
||||
use serde_json::{Number, Value};
|
||||
use std::collections::{HashMap, VecDeque};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::{Duration, Instant};
|
||||
use uuid::Uuid;
|
||||
|
||||
@@ -214,6 +213,25 @@ impl ConversationManager {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Set tool calls on a streaming message
|
||||
pub fn set_tool_calls_on_message(
|
||||
&mut self,
|
||||
message_id: Uuid,
|
||||
tool_calls: Vec<crate::types::ToolCall>,
|
||||
) -> Result<()> {
|
||||
let index = self
|
||||
.message_index
|
||||
.get(&message_id)
|
||||
.copied()
|
||||
.ok_or_else(|| crate::Error::Unknown(format!("Unknown message id: {message_id}")))?;
|
||||
|
||||
if let Some(message) = self.active_mut().messages.get_mut(index) {
|
||||
message.tool_calls = Some(tool_calls);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Update the active model (used when user changes model mid session)
|
||||
pub fn set_model(&mut self, model: impl Into<String>) {
|
||||
self.active.model = model.into();
|
||||
@@ -268,36 +286,40 @@ impl ConversationManager {
|
||||
}
|
||||
|
||||
/// Save the active conversation to disk
|
||||
pub fn save_active(&self, storage: &StorageManager, name: Option<String>) -> Result<PathBuf> {
|
||||
storage.save_conversation(&self.active, name)
|
||||
pub async fn save_active(
|
||||
&self,
|
||||
storage: &StorageManager,
|
||||
name: Option<String>,
|
||||
) -> Result<Uuid> {
|
||||
storage.save_conversation(&self.active, name).await?;
|
||||
Ok(self.active.id)
|
||||
}
|
||||
|
||||
/// Save the active conversation to disk with a description
|
||||
pub fn save_active_with_description(
|
||||
pub async fn save_active_with_description(
|
||||
&self,
|
||||
storage: &StorageManager,
|
||||
name: Option<String>,
|
||||
description: Option<String>,
|
||||
) -> Result<PathBuf> {
|
||||
storage.save_conversation_with_description(&self.active, name, description)
|
||||
) -> Result<Uuid> {
|
||||
storage
|
||||
.save_conversation_with_description(&self.active, name, description)
|
||||
.await?;
|
||||
Ok(self.active.id)
|
||||
}
|
||||
|
||||
/// Load a conversation from disk and make it active
|
||||
pub fn load_from_disk(
|
||||
&mut self,
|
||||
storage: &StorageManager,
|
||||
path: impl AsRef<Path>,
|
||||
) -> Result<()> {
|
||||
let conversation = storage.load_conversation(path)?;
|
||||
/// Load a conversation from storage and make it active
|
||||
pub async fn load_saved(&mut self, storage: &StorageManager, id: Uuid) -> Result<()> {
|
||||
let conversation = storage.load_conversation(id).await?;
|
||||
self.load(conversation);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// List all saved sessions
|
||||
pub fn list_saved_sessions(
|
||||
pub async fn list_saved_sessions(
|
||||
storage: &StorageManager,
|
||||
) -> Result<Vec<crate::storage::SessionMeta>> {
|
||||
storage.list_sessions()
|
||||
storage.list_sessions().await
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,28 +4,42 @@
|
||||
//! LLM providers, routers, and MCP (Model Context Protocol) adapters.
|
||||
|
||||
pub mod config;
|
||||
pub mod consent;
|
||||
pub mod conversation;
|
||||
pub mod credentials;
|
||||
pub mod encryption;
|
||||
pub mod formatting;
|
||||
pub mod input;
|
||||
pub mod mcp;
|
||||
pub mod model;
|
||||
pub mod provider;
|
||||
pub mod router;
|
||||
pub mod sandbox;
|
||||
pub mod session;
|
||||
pub mod storage;
|
||||
pub mod theme;
|
||||
pub mod tools;
|
||||
pub mod types;
|
||||
pub mod ui;
|
||||
pub mod validation;
|
||||
pub mod wrap_cursor;
|
||||
|
||||
pub use config::*;
|
||||
pub use consent::*;
|
||||
pub use conversation::*;
|
||||
pub use credentials::*;
|
||||
pub use encryption::*;
|
||||
pub use formatting::*;
|
||||
pub use input::*;
|
||||
pub use mcp::*;
|
||||
pub use model::*;
|
||||
pub use provider::*;
|
||||
pub use router::*;
|
||||
pub use sandbox::*;
|
||||
pub use session::*;
|
||||
pub use theme::*;
|
||||
pub use tools::*;
|
||||
pub use validation::*;
|
||||
|
||||
/// Result type used throughout the OWLEN ecosystem
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
82
crates/owlen-core/src/mcp/mod.rs
Normal file
82
crates/owlen-core/src/mcp/mod.rs
Normal file
@@ -0,0 +1,82 @@
|
||||
use crate::tools::registry::ToolRegistry;
|
||||
use crate::validation::SchemaValidator;
|
||||
use crate::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
/// Descriptor for a tool exposed over MCP
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpToolDescriptor {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub input_schema: Value,
|
||||
pub requires_network: bool,
|
||||
pub requires_filesystem: Vec<String>,
|
||||
}
|
||||
|
||||
/// Invocation payload for a tool call
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpToolCall {
|
||||
pub name: String,
|
||||
pub arguments: Value,
|
||||
}
|
||||
|
||||
/// Result returned by a tool invocation
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct McpToolResponse {
|
||||
pub name: String,
|
||||
pub success: bool,
|
||||
pub output: Value,
|
||||
pub metadata: HashMap<String, String>,
|
||||
pub duration_ms: u128,
|
||||
}
|
||||
|
||||
/// Thin MCP server facade over the tool registry
|
||||
pub struct McpServer {
|
||||
registry: Arc<ToolRegistry>,
|
||||
validator: Arc<SchemaValidator>,
|
||||
}
|
||||
|
||||
impl McpServer {
|
||||
pub fn new(registry: Arc<ToolRegistry>, validator: Arc<SchemaValidator>) -> Self {
|
||||
Self {
|
||||
registry,
|
||||
validator,
|
||||
}
|
||||
}
|
||||
|
||||
/// Enumerate the registered tools as MCP descriptors
|
||||
pub fn list_tools(&self) -> Vec<McpToolDescriptor> {
|
||||
self.registry
|
||||
.all()
|
||||
.into_iter()
|
||||
.map(|tool| McpToolDescriptor {
|
||||
name: tool.name().to_string(),
|
||||
description: tool.description().to_string(),
|
||||
input_schema: tool.schema(),
|
||||
requires_network: tool.requires_network(),
|
||||
requires_filesystem: tool.requires_filesystem(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Execute a tool call after validating inputs against the registered schema
|
||||
pub async fn call_tool(&self, call: McpToolCall) -> Result<McpToolResponse> {
|
||||
self.validator.validate(&call.name, &call.arguments)?;
|
||||
let result = self.registry.execute(&call.name, call.arguments).await?;
|
||||
Ok(McpToolResponse {
|
||||
name: call.name,
|
||||
success: result.success,
|
||||
output: result.output,
|
||||
metadata: result.metadata,
|
||||
duration_ms: duration_to_millis(result.duration),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn duration_to_millis(duration: Duration) -> u128 {
|
||||
duration.as_secs() as u128 * 1_000 + u128::from(duration.subsec_millis())
|
||||
}
|
||||
212
crates/owlen-core/src/sandbox.rs
Normal file
212
crates/owlen-core/src/sandbox.rs
Normal file
@@ -0,0 +1,212 @@
|
||||
use std::path::PathBuf;
|
||||
use std::process::{Command, Stdio};
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use anyhow::{bail, Context, Result};
|
||||
use tempfile::TempDir;
|
||||
|
||||
/// Configuration options for sandboxed process execution.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct SandboxConfig {
|
||||
pub allow_network: bool,
|
||||
pub allow_paths: Vec<PathBuf>,
|
||||
pub readonly_paths: Vec<PathBuf>,
|
||||
pub timeout_seconds: u64,
|
||||
pub max_memory_mb: u64,
|
||||
}
|
||||
|
||||
impl Default for SandboxConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
allow_network: false,
|
||||
allow_paths: Vec::new(),
|
||||
readonly_paths: Vec::new(),
|
||||
timeout_seconds: 30,
|
||||
max_memory_mb: 512,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper around a bubblewrap sandbox instance.
|
||||
///
|
||||
/// Memory limits are enforced via:
|
||||
/// - bwrap's --rlimit-as (version >= 0.12.0)
|
||||
/// - prlimit wrapper (fallback for older bwrap versions)
|
||||
/// - timeout mechanism (always enforced as last resort)
|
||||
pub struct SandboxedProcess {
|
||||
temp_dir: TempDir,
|
||||
config: SandboxConfig,
|
||||
}
|
||||
|
||||
impl SandboxedProcess {
|
||||
pub fn new(config: SandboxConfig) -> Result<Self> {
|
||||
let temp_dir = TempDir::new().context("Failed to create temp directory")?;
|
||||
|
||||
which::which("bwrap")
|
||||
.context("bubblewrap not found. Install with: sudo apt install bubblewrap")?;
|
||||
|
||||
Ok(Self { temp_dir, config })
|
||||
}
|
||||
|
||||
pub fn execute(&self, command: &str, args: &[&str]) -> Result<SandboxResult> {
|
||||
let supports_rlimit = self.supports_rlimit_as();
|
||||
let use_prlimit = !supports_rlimit && which::which("prlimit").is_ok();
|
||||
|
||||
let mut cmd = if use_prlimit {
|
||||
// Use prlimit wrapper for older bwrap versions
|
||||
let mut prlimit_cmd = Command::new("prlimit");
|
||||
let memory_limit_bytes = self
|
||||
.config
|
||||
.max_memory_mb
|
||||
.saturating_mul(1024)
|
||||
.saturating_mul(1024);
|
||||
prlimit_cmd.arg(format!("--as={}", memory_limit_bytes));
|
||||
prlimit_cmd.arg("bwrap");
|
||||
prlimit_cmd
|
||||
} else {
|
||||
Command::new("bwrap")
|
||||
};
|
||||
|
||||
cmd.args(["--unshare-all", "--die-with-parent", "--new-session"]);
|
||||
|
||||
if self.config.allow_network {
|
||||
cmd.arg("--share-net");
|
||||
} else {
|
||||
cmd.arg("--unshare-net");
|
||||
}
|
||||
|
||||
cmd.args(["--proc", "/proc", "--dev", "/dev", "--tmpfs", "/tmp"]);
|
||||
|
||||
// Bind essential system paths readonly for executables and libraries
|
||||
let system_paths = ["/usr", "/bin", "/lib", "/lib64", "/etc"];
|
||||
for sys_path in &system_paths {
|
||||
let path = std::path::Path::new(sys_path);
|
||||
if path.exists() {
|
||||
cmd.arg("--ro-bind").arg(sys_path).arg(sys_path);
|
||||
}
|
||||
}
|
||||
|
||||
// Bind /run for DNS resolution (resolv.conf may be a symlink to /run/systemd/resolve/*)
|
||||
if std::path::Path::new("/run").exists() {
|
||||
cmd.arg("--ro-bind").arg("/run").arg("/run");
|
||||
}
|
||||
|
||||
for path in &self.config.allow_paths {
|
||||
let path_host = path.to_string_lossy().into_owned();
|
||||
let path_guest = path_host.clone();
|
||||
cmd.arg("--bind").arg(&path_host).arg(&path_guest);
|
||||
}
|
||||
|
||||
for path in &self.config.readonly_paths {
|
||||
let path_host = path.to_string_lossy().into_owned();
|
||||
let path_guest = path_host.clone();
|
||||
cmd.arg("--ro-bind").arg(&path_host).arg(&path_guest);
|
||||
}
|
||||
|
||||
let work_dir = self.temp_dir.path().to_string_lossy().into_owned();
|
||||
cmd.arg("--bind").arg(&work_dir).arg("/work");
|
||||
cmd.arg("--chdir").arg("/work");
|
||||
|
||||
// Add memory limits via bwrap's --rlimit-as if supported (version >= 0.12.0)
|
||||
// If not supported, we use prlimit wrapper (set earlier)
|
||||
if supports_rlimit && !use_prlimit {
|
||||
let memory_limit_bytes = self
|
||||
.config
|
||||
.max_memory_mb
|
||||
.saturating_mul(1024)
|
||||
.saturating_mul(1024);
|
||||
let memory_soft = memory_limit_bytes.to_string();
|
||||
let memory_hard = memory_limit_bytes.to_string();
|
||||
cmd.arg("--rlimit-as").arg(&memory_soft).arg(&memory_hard);
|
||||
}
|
||||
|
||||
cmd.arg(command);
|
||||
cmd.args(args);
|
||||
|
||||
let start = Instant::now();
|
||||
let timeout = Duration::from_secs(self.config.timeout_seconds);
|
||||
|
||||
// Spawn the process instead of waiting immediately
|
||||
let mut child = cmd
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.spawn()
|
||||
.context("Failed to spawn sandboxed command")?;
|
||||
|
||||
let mut was_timeout = false;
|
||||
|
||||
// Wait for the child with timeout
|
||||
let output = loop {
|
||||
match child.try_wait() {
|
||||
Ok(Some(_status)) => {
|
||||
// Process exited
|
||||
let output = child
|
||||
.wait_with_output()
|
||||
.context("Failed to collect process output")?;
|
||||
break output;
|
||||
}
|
||||
Ok(None) => {
|
||||
// Process still running, check timeout
|
||||
if start.elapsed() >= timeout {
|
||||
// Timeout exceeded, kill the process
|
||||
was_timeout = true;
|
||||
child.kill().context("Failed to kill timed-out process")?;
|
||||
// Wait for the killed process to exit
|
||||
let output = child
|
||||
.wait_with_output()
|
||||
.context("Failed to collect output from killed process")?;
|
||||
break output;
|
||||
}
|
||||
// Sleep briefly before checking again
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
}
|
||||
Err(e) => {
|
||||
bail!("Failed to check process status: {}", e);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let duration = start.elapsed();
|
||||
|
||||
Ok(SandboxResult {
|
||||
stdout: String::from_utf8_lossy(&output.stdout).to_string(),
|
||||
stderr: String::from_utf8_lossy(&output.stderr).to_string(),
|
||||
exit_code: output.status.code().unwrap_or(-1),
|
||||
duration,
|
||||
was_timeout,
|
||||
})
|
||||
}
|
||||
|
||||
/// Check if bubblewrap supports --rlimit-as option (version >= 0.12.0)
|
||||
fn supports_rlimit_as(&self) -> bool {
|
||||
// Try to get bwrap version
|
||||
let output = Command::new("bwrap").arg("--version").output();
|
||||
|
||||
if let Ok(output) = output {
|
||||
let version_str = String::from_utf8_lossy(&output.stdout);
|
||||
// Parse version like "bubblewrap 0.11.0" or "0.11.0"
|
||||
if let Some(version_part) = version_str.split_whitespace().last() {
|
||||
if let Some((major, rest)) = version_part.split_once('.') {
|
||||
if let Some((minor, _patch)) = rest.split_once('.') {
|
||||
if let (Ok(maj), Ok(min)) = (major.parse::<u32>(), minor.parse::<u32>()) {
|
||||
// --rlimit-as was added in 0.12.0
|
||||
return maj > 0 || (maj == 0 && min >= 12);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If we can't determine the version, assume it doesn't support it (safer default)
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SandboxResult {
|
||||
pub stdout: String,
|
||||
pub stderr: String,
|
||||
pub exit_code: i32,
|
||||
pub duration: Duration,
|
||||
pub was_timeout: bool,
|
||||
}
|
||||
@@ -1,12 +1,26 @@
|
||||
use crate::config::Config;
|
||||
use crate::consent::ConsentManager;
|
||||
use crate::conversation::ConversationManager;
|
||||
use crate::credentials::CredentialManager;
|
||||
use crate::encryption::{self, VaultHandle};
|
||||
use crate::formatting::MessageFormatter;
|
||||
use crate::input::InputBuffer;
|
||||
use crate::model::ModelManager;
|
||||
use crate::provider::{ChatStream, Provider};
|
||||
use crate::types::{ChatParameters, ChatRequest, ChatResponse, Conversation, ModelInfo};
|
||||
use crate::Result;
|
||||
use std::sync::Arc;
|
||||
use crate::storage::{SessionMeta, StorageManager};
|
||||
use crate::tools::{
|
||||
code_exec::CodeExecTool, registry::ToolRegistry, web_search::WebSearchTool,
|
||||
web_search_detailed::WebSearchDetailedTool, Tool,
|
||||
};
|
||||
use crate::types::{
|
||||
ChatParameters, ChatRequest, ChatResponse, Conversation, Message, ModelInfo, ToolCall,
|
||||
};
|
||||
use crate::validation::{get_builtin_schemas, SchemaValidator};
|
||||
use crate::{Error, Result};
|
||||
use log::warn;
|
||||
use std::env;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Outcome of submitting a chat request
|
||||
@@ -31,6 +45,7 @@ pub enum SessionOutcome {
|
||||
/// use owlen_core::config::Config;
|
||||
/// use owlen_core::provider::{Provider, ChatStream};
|
||||
/// use owlen_core::session::{SessionController, SessionOutcome};
|
||||
/// use owlen_core::storage::StorageManager;
|
||||
/// use owlen_core::types::{ChatRequest, ChatResponse, ChatParameters, Message, ModelInfo};
|
||||
/// use owlen_core::Result;
|
||||
///
|
||||
@@ -55,7 +70,9 @@ pub enum SessionOutcome {
|
||||
/// async fn main() {
|
||||
/// let provider = Arc::new(MockProvider);
|
||||
/// let config = Config::default();
|
||||
/// let mut session_controller = SessionController::new(provider, config);
|
||||
/// let storage = Arc::new(StorageManager::new().await.unwrap());
|
||||
/// let enable_code_tools = false; // Set to true for code client
|
||||
/// let mut session_controller = SessionController::new(provider, config, storage, enable_code_tools).unwrap();
|
||||
///
|
||||
/// // Send a message
|
||||
/// let outcome = session_controller.send_message(
|
||||
@@ -82,17 +99,69 @@ pub struct SessionController {
|
||||
input_buffer: InputBuffer,
|
||||
formatter: MessageFormatter,
|
||||
config: Config,
|
||||
consent_manager: Arc<Mutex<ConsentManager>>,
|
||||
tool_registry: Arc<ToolRegistry>,
|
||||
schema_validator: Arc<SchemaValidator>,
|
||||
storage: Arc<StorageManager>,
|
||||
vault: Option<Arc<Mutex<VaultHandle>>>,
|
||||
master_key: Option<Arc<Vec<u8>>>,
|
||||
credential_manager: Option<Arc<CredentialManager>>,
|
||||
enable_code_tools: bool, // Whether to enable code execution tools (code client only)
|
||||
}
|
||||
|
||||
impl SessionController {
|
||||
/// Create a new controller with the given provider and configuration
|
||||
pub fn new(provider: Arc<dyn Provider>, config: Config) -> Self {
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `provider` - The LLM provider to use
|
||||
/// * `config` - Application configuration
|
||||
/// * `storage` - Storage manager for persistence
|
||||
/// * `enable_code_tools` - Whether to enable code execution tools (should only be true for code client)
|
||||
pub fn new(
|
||||
provider: Arc<dyn Provider>,
|
||||
config: Config,
|
||||
storage: Arc<StorageManager>,
|
||||
enable_code_tools: bool,
|
||||
) -> Result<Self> {
|
||||
let model = config
|
||||
.general
|
||||
.default_model
|
||||
.clone()
|
||||
.unwrap_or_else(|| "ollama/default".to_string());
|
||||
|
||||
let mut vault_handle: Option<Arc<Mutex<VaultHandle>>> = None;
|
||||
let mut master_key: Option<Arc<Vec<u8>>> = None;
|
||||
let mut credential_manager: Option<Arc<CredentialManager>> = None;
|
||||
|
||||
if config.privacy.encrypt_local_data {
|
||||
let base_dir = storage
|
||||
.database_path()
|
||||
.parent()
|
||||
.map(|p| p.to_path_buf())
|
||||
.or_else(dirs::data_local_dir)
|
||||
.unwrap_or_else(|| PathBuf::from("."));
|
||||
let secure_path = base_dir.join("encrypted_data.json");
|
||||
|
||||
let handle = match env::var("OWLEN_MASTER_PASSWORD") {
|
||||
Ok(password) if !password.is_empty() => {
|
||||
encryption::unlock_with_password(secure_path, &password)?
|
||||
}
|
||||
_ => encryption::unlock_interactive(secure_path)?,
|
||||
};
|
||||
|
||||
let master = Arc::new(handle.data.master_key.clone());
|
||||
master_key = Some(master.clone());
|
||||
vault_handle = Some(Arc::new(Mutex::new(handle)));
|
||||
credential_manager = Some(Arc::new(CredentialManager::new(storage.clone(), master)));
|
||||
}
|
||||
|
||||
// Load consent manager from vault if available, otherwise create new
|
||||
let consent_manager = if let Some(ref vault) = vault_handle {
|
||||
Arc::new(Mutex::new(ConsentManager::from_vault(vault)))
|
||||
} else {
|
||||
Arc::new(Mutex::new(ConsentManager::new()))
|
||||
};
|
||||
|
||||
let conversation =
|
||||
ConversationManager::with_history_capacity(model, config.storage.max_saved_sessions);
|
||||
let formatter =
|
||||
@@ -106,14 +175,26 @@ impl SessionController {
|
||||
|
||||
let model_manager = ModelManager::new(config.general.model_cache_ttl());
|
||||
|
||||
Self {
|
||||
let mut controller = Self {
|
||||
provider,
|
||||
conversation,
|
||||
model_manager,
|
||||
input_buffer,
|
||||
formatter,
|
||||
config,
|
||||
}
|
||||
consent_manager,
|
||||
tool_registry: Arc::new(ToolRegistry::new()),
|
||||
schema_validator: Arc::new(SchemaValidator::new()),
|
||||
storage,
|
||||
vault: vault_handle,
|
||||
master_key,
|
||||
credential_manager,
|
||||
enable_code_tools,
|
||||
};
|
||||
|
||||
controller.rebuild_tools()?;
|
||||
|
||||
Ok(controller)
|
||||
}
|
||||
|
||||
/// Access the active conversation
|
||||
@@ -156,6 +237,260 @@ impl SessionController {
|
||||
&mut self.config
|
||||
}
|
||||
|
||||
/// Grant consent programmatically for a tool (for TUI consent dialog)
|
||||
pub fn grant_consent(&self, tool_name: &str, data_types: Vec<String>, endpoints: Vec<String>) {
|
||||
let mut consent = self
|
||||
.consent_manager
|
||||
.lock()
|
||||
.expect("Consent manager mutex poisoned");
|
||||
consent.grant_consent(tool_name, data_types, endpoints);
|
||||
|
||||
// Persist to vault if available
|
||||
if let Some(vault) = &self.vault {
|
||||
if let Err(e) = consent.persist_to_vault(vault) {
|
||||
eprintln!("Warning: Failed to persist consent to vault: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if consent is needed for tool calls (non-blocking check)
|
||||
/// Returns a list of (tool_name, data_types, endpoints) tuples for tools that need consent
|
||||
pub fn check_tools_consent_needed(
|
||||
&self,
|
||||
tool_calls: &[ToolCall],
|
||||
) -> Vec<(String, Vec<String>, Vec<String>)> {
|
||||
let consent = self
|
||||
.consent_manager
|
||||
.lock()
|
||||
.expect("Consent manager mutex poisoned");
|
||||
let mut needs_consent = Vec::new();
|
||||
let mut seen_tools = std::collections::HashSet::new();
|
||||
|
||||
for tool_call in tool_calls {
|
||||
// Skip if we already checked this tool
|
||||
if seen_tools.contains(&tool_call.name) {
|
||||
continue;
|
||||
}
|
||||
seen_tools.insert(tool_call.name.clone());
|
||||
|
||||
// Get tool metadata (data types and endpoints) based on tool name
|
||||
let (data_types, endpoints) = match tool_call.name.as_str() {
|
||||
"web_search" | "web_search_detailed" => (
|
||||
vec!["search query".to_string()],
|
||||
vec!["duckduckgo.com".to_string()],
|
||||
),
|
||||
"code_exec" => (
|
||||
vec!["code to execute".to_string()],
|
||||
vec!["local sandbox".to_string()],
|
||||
),
|
||||
_ => (vec![], vec![]),
|
||||
};
|
||||
|
||||
if let Some((tool_name, dt, ep)) =
|
||||
consent.check_if_consent_needed(&tool_call.name, data_types, endpoints)
|
||||
{
|
||||
needs_consent.push((tool_name, dt, ep));
|
||||
}
|
||||
}
|
||||
|
||||
needs_consent
|
||||
}
|
||||
|
||||
/// Persist the active conversation to storage
|
||||
pub async fn save_active_session(
|
||||
&self,
|
||||
name: Option<String>,
|
||||
description: Option<String>,
|
||||
) -> Result<Uuid> {
|
||||
self.conversation
|
||||
.save_active_with_description(&self.storage, name, description)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Persist the active conversation without description override
|
||||
pub async fn save_active_session_simple(&self, name: Option<String>) -> Result<Uuid> {
|
||||
self.conversation.save_active(&self.storage, name).await
|
||||
}
|
||||
|
||||
/// Load a saved conversation by ID and make it active
|
||||
pub async fn load_saved_session(&mut self, id: Uuid) -> Result<()> {
|
||||
self.conversation.load_saved(&self.storage, id).await
|
||||
}
|
||||
|
||||
/// Retrieve session metadata from storage
|
||||
pub async fn list_saved_sessions(&self) -> Result<Vec<SessionMeta>> {
|
||||
ConversationManager::list_saved_sessions(&self.storage).await
|
||||
}
|
||||
|
||||
pub async fn delete_session(&self, id: Uuid) -> Result<()> {
|
||||
self.storage.delete_session(id).await
|
||||
}
|
||||
|
||||
pub async fn clear_secure_data(&self) -> Result<()> {
|
||||
self.storage.clear_secure_items().await?;
|
||||
if let Some(vault) = &self.vault {
|
||||
let mut guard = vault.lock().expect("Vault mutex poisoned");
|
||||
guard.data.settings.clear();
|
||||
guard.persist()?;
|
||||
}
|
||||
// Also clear consent records
|
||||
{
|
||||
let mut consent = self
|
||||
.consent_manager
|
||||
.lock()
|
||||
.expect("Consent manager mutex poisoned");
|
||||
consent.clear_all_consent();
|
||||
}
|
||||
self.persist_consent()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Persist current consent state to vault (if encryption is enabled)
|
||||
pub fn persist_consent(&self) -> Result<()> {
|
||||
if let Some(vault) = &self.vault {
|
||||
let consent = self
|
||||
.consent_manager
|
||||
.lock()
|
||||
.expect("Consent manager mutex poisoned");
|
||||
consent.persist_to_vault(vault)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn set_tool_enabled(&mut self, tool: &str, enabled: bool) -> Result<()> {
|
||||
match tool {
|
||||
"web_search" => {
|
||||
self.config.tools.web_search.enabled = enabled;
|
||||
self.config.privacy.enable_remote_search = enabled;
|
||||
}
|
||||
"code_exec" => {
|
||||
self.config.tools.code_exec.enabled = enabled;
|
||||
}
|
||||
other => {
|
||||
return Err(Error::InvalidInput(format!("Unknown tool: {other}")));
|
||||
}
|
||||
}
|
||||
|
||||
self.rebuild_tools()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Access the consent manager shared across tools
|
||||
pub fn consent_manager(&self) -> Arc<Mutex<ConsentManager>> {
|
||||
self.consent_manager.clone()
|
||||
}
|
||||
|
||||
/// Access the tool registry for executing registered tools
|
||||
pub fn tool_registry(&self) -> Arc<ToolRegistry> {
|
||||
Arc::clone(&self.tool_registry)
|
||||
}
|
||||
|
||||
/// Access the schema validator used for tool input validation
|
||||
pub fn schema_validator(&self) -> Arc<SchemaValidator> {
|
||||
Arc::clone(&self.schema_validator)
|
||||
}
|
||||
|
||||
/// Construct an MCP server facade for the active tool registry
|
||||
pub fn mcp_server(&self) -> crate::mcp::McpServer {
|
||||
crate::mcp::McpServer::new(self.tool_registry(), self.schema_validator())
|
||||
}
|
||||
|
||||
/// Access the underlying storage manager
|
||||
pub fn storage(&self) -> Arc<StorageManager> {
|
||||
Arc::clone(&self.storage)
|
||||
}
|
||||
|
||||
/// Retrieve the active master key if encryption is enabled
|
||||
pub fn master_key(&self) -> Option<Arc<Vec<u8>>> {
|
||||
self.master_key.as_ref().map(Arc::clone)
|
||||
}
|
||||
|
||||
/// Access the vault handle for managing secure settings
|
||||
pub fn vault(&self) -> Option<Arc<Mutex<VaultHandle>>> {
|
||||
self.vault.as_ref().map(Arc::clone)
|
||||
}
|
||||
|
||||
/// Access the credential manager if available
|
||||
pub fn credential_manager(&self) -> Option<Arc<CredentialManager>> {
|
||||
self.credential_manager.as_ref().map(Arc::clone)
|
||||
}
|
||||
|
||||
fn rebuild_tools(&mut self) -> Result<()> {
|
||||
let mut registry = ToolRegistry::new();
|
||||
let mut validator = SchemaValidator::new();
|
||||
|
||||
for (name, schema) in get_builtin_schemas() {
|
||||
if let Err(err) = validator.register_schema(&name, schema) {
|
||||
warn!("Failed to register built-in schema {name}: {err}");
|
||||
}
|
||||
}
|
||||
|
||||
if self
|
||||
.config
|
||||
.security
|
||||
.allowed_tools
|
||||
.iter()
|
||||
.any(|tool| tool == "web_search")
|
||||
&& self.config.tools.web_search.enabled
|
||||
&& self.config.privacy.enable_remote_search
|
||||
{
|
||||
let tool = WebSearchTool::new(
|
||||
self.consent_manager.clone(),
|
||||
self.credential_manager.clone(),
|
||||
self.vault.clone(),
|
||||
);
|
||||
let schema = tool.schema();
|
||||
if let Err(err) = validator.register_schema(tool.name(), schema) {
|
||||
warn!("Failed to register schema for {}: {err}", tool.name());
|
||||
}
|
||||
registry.register(tool);
|
||||
}
|
||||
|
||||
// Register web_search_detailed tool (provides snippets)
|
||||
if self
|
||||
.config
|
||||
.security
|
||||
.allowed_tools
|
||||
.iter()
|
||||
.any(|tool| tool == "web_search") // Same permission as web_search
|
||||
&& self.config.tools.web_search.enabled
|
||||
&& self.config.privacy.enable_remote_search
|
||||
{
|
||||
let tool = WebSearchDetailedTool::new(
|
||||
self.consent_manager.clone(),
|
||||
self.credential_manager.clone(),
|
||||
self.vault.clone(),
|
||||
);
|
||||
let schema = tool.schema();
|
||||
if let Err(err) = validator.register_schema(tool.name(), schema) {
|
||||
warn!("Failed to register schema for {}: {err}", tool.name());
|
||||
}
|
||||
registry.register(tool);
|
||||
}
|
||||
|
||||
// Code execution tool - only available in code client
|
||||
if self.enable_code_tools
|
||||
&& self
|
||||
.config
|
||||
.security
|
||||
.allowed_tools
|
||||
.iter()
|
||||
.any(|tool| tool == "code_exec")
|
||||
&& self.config.tools.code_exec.enabled
|
||||
{
|
||||
let tool = CodeExecTool::new(self.config.tools.code_exec.allowed_languages.clone());
|
||||
let schema = tool.schema();
|
||||
if let Err(err) = validator.register_schema(tool.name(), schema) {
|
||||
warn!("Failed to register schema for {}: {err}", tool.name());
|
||||
}
|
||||
registry.register(tool);
|
||||
}
|
||||
|
||||
self.tool_registry = Arc::new(registry);
|
||||
self.schema_validator = Arc::new(validator);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Currently selected model identifier
|
||||
pub fn selected_model(&self) -> &str {
|
||||
&self.conversation.active().model
|
||||
@@ -187,6 +522,13 @@ impl SessionController {
|
||||
}
|
||||
}
|
||||
|
||||
/// Replace the active provider at runtime and invalidate cached model listings
|
||||
pub async fn switch_provider(&mut self, provider: Arc<dyn Provider>) -> Result<()> {
|
||||
self.provider = provider;
|
||||
self.model_manager.invalidate().await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Submit a user message; optionally stream the response
|
||||
pub async fn send_message(
|
||||
&mut self,
|
||||
@@ -210,38 +552,104 @@ impl SessionController {
|
||||
let streaming = parameters.stream || self.config.general.enable_streaming;
|
||||
parameters.stream = streaming;
|
||||
|
||||
let request = ChatRequest {
|
||||
model: self.conversation.active().model.clone(),
|
||||
messages: self.conversation.active().messages.clone(),
|
||||
parameters,
|
||||
// Get available tools
|
||||
let tools = if !self.tool_registry.all().is_empty() {
|
||||
Some(
|
||||
self.tool_registry
|
||||
.all()
|
||||
.into_iter()
|
||||
.map(|tool| crate::mcp::McpToolDescriptor {
|
||||
name: tool.name().to_string(),
|
||||
description: tool.description().to_string(),
|
||||
input_schema: tool.schema(),
|
||||
requires_network: tool.requires_network(),
|
||||
requires_filesystem: tool.requires_filesystem(),
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if streaming {
|
||||
match self.provider.chat_stream(request).await {
|
||||
Ok(stream) => {
|
||||
let response_id = self.conversation.start_streaming_response();
|
||||
Ok(SessionOutcome::Streaming {
|
||||
response_id,
|
||||
stream,
|
||||
})
|
||||
}
|
||||
Err(err) => {
|
||||
self.conversation
|
||||
.push_assistant_message(format!("Error starting stream: {}", err));
|
||||
Err(err)
|
||||
let mut request = ChatRequest {
|
||||
model: self.conversation.active().model.clone(),
|
||||
messages: self.conversation.active().messages.clone(),
|
||||
parameters: parameters.clone(),
|
||||
tools: tools.clone(),
|
||||
};
|
||||
|
||||
// Tool execution loop (non-streaming only for now)
|
||||
if !streaming {
|
||||
const MAX_TOOL_ITERATIONS: usize = 5;
|
||||
for _iteration in 0..MAX_TOOL_ITERATIONS {
|
||||
match self.provider.chat(request.clone()).await {
|
||||
Ok(response) => {
|
||||
// Check if the response has tool calls
|
||||
if response.message.has_tool_calls() {
|
||||
// Add assistant's tool call message to conversation
|
||||
self.conversation.push_message(response.message.clone());
|
||||
|
||||
// Execute each tool call
|
||||
if let Some(tool_calls) = &response.message.tool_calls {
|
||||
for tool_call in tool_calls {
|
||||
let tool_result = self
|
||||
.tool_registry
|
||||
.execute(&tool_call.name, tool_call.arguments.clone())
|
||||
.await;
|
||||
|
||||
let tool_response_content = match tool_result {
|
||||
Ok(result) => serde_json::to_string_pretty(&result.output)
|
||||
.unwrap_or_else(|_| {
|
||||
"Tool execution succeeded".to_string()
|
||||
}),
|
||||
Err(e) => format!("Tool execution failed: {}", e),
|
||||
};
|
||||
|
||||
// Add tool response to conversation
|
||||
let tool_msg =
|
||||
Message::tool(tool_call.id.clone(), tool_response_content);
|
||||
self.conversation.push_message(tool_msg);
|
||||
}
|
||||
}
|
||||
|
||||
// Update request with new messages for next iteration
|
||||
request.messages = self.conversation.active().messages.clone();
|
||||
continue;
|
||||
} else {
|
||||
// No more tool calls, return final response
|
||||
self.conversation.push_message(response.message.clone());
|
||||
return Ok(SessionOutcome::Complete(response));
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
self.conversation
|
||||
.push_assistant_message(format!("Error: {}", err));
|
||||
return Err(err);
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
match self.provider.chat(request).await {
|
||||
Ok(response) => {
|
||||
self.conversation.push_message(response.message.clone());
|
||||
Ok(SessionOutcome::Complete(response))
|
||||
}
|
||||
Err(err) => {
|
||||
self.conversation
|
||||
.push_assistant_message(format!("Error: {}", err));
|
||||
Err(err)
|
||||
}
|
||||
|
||||
// Max iterations reached
|
||||
self.conversation
|
||||
.push_assistant_message("Maximum tool execution iterations reached".to_string());
|
||||
return Err(crate::Error::Provider(anyhow::anyhow!(
|
||||
"Maximum tool execution iterations reached"
|
||||
)));
|
||||
}
|
||||
|
||||
// Streaming mode with tool support
|
||||
match self.provider.chat_stream(request).await {
|
||||
Ok(stream) => {
|
||||
let response_id = self.conversation.start_streaming_response();
|
||||
Ok(SessionOutcome::Streaming {
|
||||
response_id,
|
||||
stream,
|
||||
})
|
||||
}
|
||||
Err(err) => {
|
||||
self.conversation
|
||||
.push_assistant_message(format!("Error starting stream: {}", err));
|
||||
Err(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -254,10 +662,64 @@ impl SessionController {
|
||||
|
||||
/// Apply streaming chunk to the conversation
|
||||
pub fn apply_stream_chunk(&mut self, message_id: Uuid, chunk: &ChatResponse) -> Result<()> {
|
||||
// Check if this chunk contains tool calls
|
||||
if chunk.message.has_tool_calls() {
|
||||
// This is a tool call chunk - store the tool calls on the message
|
||||
self.conversation.set_tool_calls_on_message(
|
||||
message_id,
|
||||
chunk.message.tool_calls.clone().unwrap_or_default(),
|
||||
)?;
|
||||
}
|
||||
|
||||
self.conversation
|
||||
.append_stream_chunk(message_id, &chunk.message.content, chunk.is_final)
|
||||
}
|
||||
|
||||
/// Check if a streaming message has complete tool calls that need execution
|
||||
pub fn check_streaming_tool_calls(&self, message_id: Uuid) -> Option<Vec<ToolCall>> {
|
||||
self.conversation
|
||||
.active()
|
||||
.messages
|
||||
.iter()
|
||||
.find(|m| m.id == message_id)
|
||||
.and_then(|m| m.tool_calls.clone())
|
||||
.filter(|calls| !calls.is_empty())
|
||||
}
|
||||
|
||||
/// Execute tools for a streaming response and continue conversation
|
||||
pub async fn execute_streaming_tools(
|
||||
&mut self,
|
||||
_message_id: Uuid,
|
||||
tool_calls: Vec<ToolCall>,
|
||||
) -> Result<SessionOutcome> {
|
||||
// Execute each tool call
|
||||
for tool_call in &tool_calls {
|
||||
let tool_result = self
|
||||
.tool_registry
|
||||
.execute(&tool_call.name, tool_call.arguments.clone())
|
||||
.await;
|
||||
|
||||
let tool_response_content = match tool_result {
|
||||
Ok(result) => serde_json::to_string_pretty(&result.output)
|
||||
.unwrap_or_else(|_| "Tool execution succeeded".to_string()),
|
||||
Err(e) => format!("Tool execution failed: {}", e),
|
||||
};
|
||||
|
||||
// Add tool response to conversation
|
||||
let tool_msg = Message::tool(tool_call.id.clone(), tool_response_content);
|
||||
self.conversation.push_message(tool_msg);
|
||||
}
|
||||
|
||||
// Continue the conversation with tool results
|
||||
let parameters = ChatParameters {
|
||||
stream: self.config.general.enable_streaming,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
self.send_request_with_current_conversation(parameters)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Access conversation history
|
||||
pub fn history(&self) -> Vec<Conversation> {
|
||||
self.conversation.history().cloned().collect()
|
||||
@@ -335,6 +797,7 @@ impl SessionController {
|
||||
stream: false,
|
||||
extra: std::collections::HashMap::new(),
|
||||
},
|
||||
tools: None,
|
||||
};
|
||||
|
||||
// Get the summary from the provider
|
||||
|
||||
@@ -1,19 +1,26 @@
|
||||
//! Session persistence and storage management
|
||||
//! Session persistence and storage management backed by SQLite
|
||||
|
||||
use crate::types::Conversation;
|
||||
use crate::{Error, Result};
|
||||
use aes_gcm::aead::{Aead, KeyInit};
|
||||
use aes_gcm::{Aes256Gcm, Nonce};
|
||||
use ring::rand::{SecureRandom, SystemRandom};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode, SqlitePoolOptions, SqliteSynchronous};
|
||||
use sqlx::{Pool, Row, Sqlite};
|
||||
use std::fs;
|
||||
use std::io::IsTerminal;
|
||||
use std::io::{self, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::time::SystemTime;
|
||||
use std::str::FromStr;
|
||||
use std::time::{Duration, SystemTime, UNIX_EPOCH};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Metadata about a saved session
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SessionMeta {
|
||||
/// Session file path
|
||||
pub path: PathBuf,
|
||||
/// Conversation ID
|
||||
pub id: uuid::Uuid,
|
||||
pub id: Uuid,
|
||||
/// Optional session name
|
||||
pub name: Option<String>,
|
||||
/// Optional AI-generated description
|
||||
@@ -28,282 +35,525 @@ pub struct SessionMeta {
|
||||
pub updated_at: SystemTime,
|
||||
}
|
||||
|
||||
/// Storage manager for persisting conversations
|
||||
/// Storage manager for persisting conversations in SQLite
|
||||
pub struct StorageManager {
|
||||
sessions_dir: PathBuf,
|
||||
pool: Pool<Sqlite>,
|
||||
database_path: PathBuf,
|
||||
}
|
||||
|
||||
impl StorageManager {
|
||||
/// Create a new storage manager with the default sessions directory
|
||||
pub fn new() -> Result<Self> {
|
||||
let sessions_dir = Self::default_sessions_dir()?;
|
||||
Self::with_directory(sessions_dir)
|
||||
/// Create a new storage manager using the default database path
|
||||
pub async fn new() -> Result<Self> {
|
||||
let db_path = Self::default_database_path()?;
|
||||
Self::with_database_path(db_path).await
|
||||
}
|
||||
|
||||
/// Create a storage manager with a custom sessions directory
|
||||
pub fn with_directory(sessions_dir: PathBuf) -> Result<Self> {
|
||||
// Ensure the directory exists
|
||||
if !sessions_dir.exists() {
|
||||
fs::create_dir_all(&sessions_dir).map_err(|e| {
|
||||
Error::Storage(format!("Failed to create sessions directory: {}", e))
|
||||
})?;
|
||||
/// Create a storage manager using the provided database path
|
||||
pub async fn with_database_path(database_path: PathBuf) -> Result<Self> {
|
||||
if let Some(parent) = database_path.parent() {
|
||||
if !parent.exists() {
|
||||
std::fs::create_dir_all(parent).map_err(|e| {
|
||||
Error::Storage(format!(
|
||||
"Failed to create database directory {parent:?}: {e}"
|
||||
))
|
||||
})?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Self { sessions_dir })
|
||||
let options = SqliteConnectOptions::from_str(&format!(
|
||||
"sqlite://{}",
|
||||
database_path
|
||||
.to_str()
|
||||
.ok_or_else(|| Error::Storage("Invalid database path".to_string()))?
|
||||
))
|
||||
.map_err(|e| Error::Storage(format!("Invalid database URL: {e}")))?
|
||||
.create_if_missing(true)
|
||||
.journal_mode(SqliteJournalMode::Wal)
|
||||
.synchronous(SqliteSynchronous::Normal);
|
||||
|
||||
let pool = SqlitePoolOptions::new()
|
||||
.max_connections(5)
|
||||
.connect_with(options)
|
||||
.await
|
||||
.map_err(|e| Error::Storage(format!("Failed to connect to database: {e}")))?;
|
||||
|
||||
sqlx::migrate!("./migrations")
|
||||
.run(&pool)
|
||||
.await
|
||||
.map_err(|e| Error::Storage(format!("Failed to run database migrations: {e}")))?;
|
||||
|
||||
let storage = Self {
|
||||
pool,
|
||||
database_path,
|
||||
};
|
||||
|
||||
storage.try_migrate_legacy_sessions().await?;
|
||||
|
||||
Ok(storage)
|
||||
}
|
||||
|
||||
/// Get the default sessions directory
|
||||
/// - Linux: ~/.local/share/owlen/sessions
|
||||
/// - Windows: %APPDATA%\owlen\sessions
|
||||
/// - macOS: ~/Library/Application Support/owlen/sessions
|
||||
pub fn default_sessions_dir() -> Result<PathBuf> {
|
||||
/// Save a conversation. Existing entries are updated in-place.
|
||||
pub async fn save_conversation(
|
||||
&self,
|
||||
conversation: &Conversation,
|
||||
name: Option<String>,
|
||||
) -> Result<()> {
|
||||
self.save_conversation_with_description(conversation, name, None)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Save a conversation with an optional description override
|
||||
pub async fn save_conversation_with_description(
|
||||
&self,
|
||||
conversation: &Conversation,
|
||||
name: Option<String>,
|
||||
description: Option<String>,
|
||||
) -> Result<()> {
|
||||
let mut serialized = conversation.clone();
|
||||
if name.is_some() {
|
||||
serialized.name = name.clone();
|
||||
}
|
||||
if description.is_some() {
|
||||
serialized.description = description.clone();
|
||||
}
|
||||
|
||||
let data = serde_json::to_string(&serialized)
|
||||
.map_err(|e| Error::Storage(format!("Failed to serialize conversation: {e}")))?;
|
||||
|
||||
let created_at = to_epoch_seconds(serialized.created_at);
|
||||
let updated_at = to_epoch_seconds(serialized.updated_at);
|
||||
let message_count = serialized.messages.len() as i64;
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO conversations (
|
||||
id,
|
||||
name,
|
||||
description,
|
||||
model,
|
||||
message_count,
|
||||
created_at,
|
||||
updated_at,
|
||||
data
|
||||
) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
|
||||
ON CONFLICT(id) DO UPDATE SET
|
||||
name = excluded.name,
|
||||
description = excluded.description,
|
||||
model = excluded.model,
|
||||
message_count = excluded.message_count,
|
||||
created_at = excluded.created_at,
|
||||
updated_at = excluded.updated_at,
|
||||
data = excluded.data
|
||||
"#,
|
||||
)
|
||||
.bind(serialized.id.to_string())
|
||||
.bind(name.or(serialized.name.clone()))
|
||||
.bind(description.or(serialized.description.clone()))
|
||||
.bind(&serialized.model)
|
||||
.bind(message_count)
|
||||
.bind(created_at)
|
||||
.bind(updated_at)
|
||||
.bind(data)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| Error::Storage(format!("Failed to save conversation: {e}")))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load a conversation by ID
|
||||
pub async fn load_conversation(&self, id: Uuid) -> Result<Conversation> {
|
||||
let record = sqlx::query(r#"SELECT data FROM conversations WHERE id = ?1"#)
|
||||
.bind(id.to_string())
|
||||
.fetch_optional(&self.pool)
|
||||
.await
|
||||
.map_err(|e| Error::Storage(format!("Failed to load conversation: {e}")))?;
|
||||
|
||||
let row =
|
||||
record.ok_or_else(|| Error::Storage(format!("No conversation found with id {id}")))?;
|
||||
|
||||
let data: String = row
|
||||
.try_get("data")
|
||||
.map_err(|e| Error::Storage(format!("Failed to read conversation payload: {e}")))?;
|
||||
|
||||
serde_json::from_str(&data)
|
||||
.map_err(|e| Error::Storage(format!("Failed to deserialize conversation: {e}")))
|
||||
}
|
||||
|
||||
/// List metadata for all saved conversations ordered by most recent update
|
||||
pub async fn list_sessions(&self) -> Result<Vec<SessionMeta>> {
|
||||
let rows = sqlx::query(
|
||||
r#"
|
||||
SELECT id, name, description, model, message_count, created_at, updated_at
|
||||
FROM conversations
|
||||
ORDER BY updated_at DESC
|
||||
"#,
|
||||
)
|
||||
.fetch_all(&self.pool)
|
||||
.await
|
||||
.map_err(|e| Error::Storage(format!("Failed to list sessions: {e}")))?;
|
||||
|
||||
let mut sessions = Vec::with_capacity(rows.len());
|
||||
for row in rows {
|
||||
let id_text: String = row
|
||||
.try_get("id")
|
||||
.map_err(|e| Error::Storage(format!("Failed to read id column: {e}")))?;
|
||||
let id = Uuid::parse_str(&id_text)
|
||||
.map_err(|e| Error::Storage(format!("Invalid UUID in storage: {e}")))?;
|
||||
|
||||
let message_count: i64 = row
|
||||
.try_get("message_count")
|
||||
.map_err(|e| Error::Storage(format!("Failed to read message count: {e}")))?;
|
||||
|
||||
let created_at: i64 = row
|
||||
.try_get("created_at")
|
||||
.map_err(|e| Error::Storage(format!("Failed to read created_at: {e}")))?;
|
||||
let updated_at: i64 = row
|
||||
.try_get("updated_at")
|
||||
.map_err(|e| Error::Storage(format!("Failed to read updated_at: {e}")))?;
|
||||
|
||||
sessions.push(SessionMeta {
|
||||
id,
|
||||
name: row
|
||||
.try_get("name")
|
||||
.map_err(|e| Error::Storage(format!("Failed to read name: {e}")))?,
|
||||
description: row
|
||||
.try_get("description")
|
||||
.map_err(|e| Error::Storage(format!("Failed to read description: {e}")))?,
|
||||
model: row
|
||||
.try_get("model")
|
||||
.map_err(|e| Error::Storage(format!("Failed to read model: {e}")))?,
|
||||
message_count: message_count as usize,
|
||||
created_at: from_epoch_seconds(created_at),
|
||||
updated_at: from_epoch_seconds(updated_at),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(sessions)
|
||||
}
|
||||
|
||||
/// Delete a conversation by ID
|
||||
pub async fn delete_session(&self, id: Uuid) -> Result<()> {
|
||||
sqlx::query("DELETE FROM conversations WHERE id = ?1")
|
||||
.bind(id.to_string())
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| Error::Storage(format!("Failed to delete conversation: {e}")))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn store_secure_item(
|
||||
&self,
|
||||
key: &str,
|
||||
plaintext: &[u8],
|
||||
master_key: &[u8],
|
||||
) -> Result<()> {
|
||||
let cipher = create_cipher(master_key)?;
|
||||
let nonce_bytes = generate_nonce()?;
|
||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||
let ciphertext = cipher
|
||||
.encrypt(nonce, plaintext)
|
||||
.map_err(|e| Error::Storage(format!("Failed to encrypt secure item: {e}")))?;
|
||||
|
||||
let now = to_epoch_seconds(SystemTime::now());
|
||||
|
||||
sqlx::query(
|
||||
r#"
|
||||
INSERT INTO secure_items (key, nonce, ciphertext, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5)
|
||||
ON CONFLICT(key) DO UPDATE SET
|
||||
nonce = excluded.nonce,
|
||||
ciphertext = excluded.ciphertext,
|
||||
updated_at = excluded.updated_at
|
||||
"#,
|
||||
)
|
||||
.bind(key)
|
||||
.bind(&nonce_bytes[..])
|
||||
.bind(&ciphertext[..])
|
||||
.bind(now)
|
||||
.bind(now)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| Error::Storage(format!("Failed to store secure item: {e}")))?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn load_secure_item(&self, key: &str, master_key: &[u8]) -> Result<Option<Vec<u8>>> {
|
||||
let record = sqlx::query("SELECT nonce, ciphertext FROM secure_items WHERE key = ?1")
|
||||
.bind(key)
|
||||
.fetch_optional(&self.pool)
|
||||
.await
|
||||
.map_err(|e| Error::Storage(format!("Failed to load secure item: {e}")))?;
|
||||
|
||||
let Some(row) = record else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let nonce_bytes: Vec<u8> = row
|
||||
.try_get("nonce")
|
||||
.map_err(|e| Error::Storage(format!("Failed to read secure item nonce: {e}")))?;
|
||||
let ciphertext: Vec<u8> = row
|
||||
.try_get("ciphertext")
|
||||
.map_err(|e| Error::Storage(format!("Failed to read secure item ciphertext: {e}")))?;
|
||||
|
||||
if nonce_bytes.len() != 12 {
|
||||
return Err(Error::Storage(
|
||||
"Invalid nonce length for secure item".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let cipher = create_cipher(master_key)?;
|
||||
let nonce = Nonce::from_slice(&nonce_bytes);
|
||||
let plaintext = cipher
|
||||
.decrypt(nonce, ciphertext.as_ref())
|
||||
.map_err(|e| Error::Storage(format!("Failed to decrypt secure item: {e}")))?;
|
||||
|
||||
Ok(Some(plaintext))
|
||||
}
|
||||
|
||||
pub async fn delete_secure_item(&self, key: &str) -> Result<()> {
|
||||
sqlx::query("DELETE FROM secure_items WHERE key = ?1")
|
||||
.bind(key)
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| Error::Storage(format!("Failed to delete secure item: {e}")))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn clear_secure_items(&self) -> Result<()> {
|
||||
sqlx::query("DELETE FROM secure_items")
|
||||
.execute(&self.pool)
|
||||
.await
|
||||
.map_err(|e| Error::Storage(format!("Failed to clear secure items: {e}")))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Database location used by this storage manager
|
||||
pub fn database_path(&self) -> &Path {
|
||||
&self.database_path
|
||||
}
|
||||
|
||||
/// Determine default database path (platform specific)
|
||||
pub fn default_database_path() -> Result<PathBuf> {
|
||||
let data_dir = dirs::data_local_dir()
|
||||
.ok_or_else(|| Error::Storage("Could not determine data directory".to_string()))?;
|
||||
Ok(data_dir.join("owlen").join("owlen.db"))
|
||||
}
|
||||
|
||||
fn legacy_sessions_dir() -> Result<PathBuf> {
|
||||
let data_dir = dirs::data_local_dir()
|
||||
.ok_or_else(|| Error::Storage("Could not determine data directory".to_string()))?;
|
||||
Ok(data_dir.join("owlen").join("sessions"))
|
||||
}
|
||||
|
||||
/// Save a conversation to disk
|
||||
pub fn save_conversation(
|
||||
&self,
|
||||
conversation: &Conversation,
|
||||
name: Option<String>,
|
||||
) -> Result<PathBuf> {
|
||||
self.save_conversation_with_description(conversation, name, None)
|
||||
async fn database_has_records(&self) -> Result<bool> {
|
||||
let (count,): (i64,) = sqlx::query_as("SELECT COUNT(*) FROM conversations")
|
||||
.fetch_one(&self.pool)
|
||||
.await
|
||||
.map_err(|e| Error::Storage(format!("Failed to inspect database: {e}")))?;
|
||||
Ok(count > 0)
|
||||
}
|
||||
|
||||
/// Save a conversation to disk with an optional description
|
||||
pub fn save_conversation_with_description(
|
||||
&self,
|
||||
conversation: &Conversation,
|
||||
name: Option<String>,
|
||||
description: Option<String>,
|
||||
) -> Result<PathBuf> {
|
||||
let filename = if let Some(ref session_name) = name {
|
||||
// Use provided name, sanitized
|
||||
let sanitized = sanitize_filename(session_name);
|
||||
format!("{}_{}.json", conversation.id, sanitized)
|
||||
} else {
|
||||
// Use conversation ID and timestamp
|
||||
let timestamp = SystemTime::now()
|
||||
.duration_since(SystemTime::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs();
|
||||
format!("{}_{}.json", conversation.id, timestamp)
|
||||
async fn try_migrate_legacy_sessions(&self) -> Result<()> {
|
||||
if self.database_has_records().await? {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let legacy_dir = match Self::legacy_sessions_dir() {
|
||||
Ok(dir) => dir,
|
||||
Err(_) => return Ok(()),
|
||||
};
|
||||
|
||||
let path = self.sessions_dir.join(filename);
|
||||
|
||||
// Create a saveable version with the name and description
|
||||
let mut save_conv = conversation.clone();
|
||||
if name.is_some() {
|
||||
save_conv.name = name;
|
||||
}
|
||||
if description.is_some() {
|
||||
save_conv.description = description;
|
||||
if !legacy_dir.exists() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let json = serde_json::to_string_pretty(&save_conv)
|
||||
.map_err(|e| Error::Storage(format!("Failed to serialize conversation: {}", e)))?;
|
||||
|
||||
fs::write(&path, json)
|
||||
.map_err(|e| Error::Storage(format!("Failed to write session file: {}", e)))?;
|
||||
|
||||
Ok(path)
|
||||
}
|
||||
|
||||
/// Load a conversation from disk
|
||||
pub fn load_conversation(&self, path: impl AsRef<Path>) -> Result<Conversation> {
|
||||
let content = fs::read_to_string(path.as_ref())
|
||||
.map_err(|e| Error::Storage(format!("Failed to read session file: {}", e)))?;
|
||||
|
||||
let conversation: Conversation = serde_json::from_str(&content)
|
||||
.map_err(|e| Error::Storage(format!("Failed to parse session file: {}", e)))?;
|
||||
|
||||
Ok(conversation)
|
||||
}
|
||||
|
||||
/// List all saved sessions with metadata
|
||||
pub fn list_sessions(&self) -> Result<Vec<SessionMeta>> {
|
||||
let mut sessions = Vec::new();
|
||||
|
||||
let entries = fs::read_dir(&self.sessions_dir)
|
||||
.map_err(|e| Error::Storage(format!("Failed to read sessions directory: {}", e)))?;
|
||||
|
||||
for entry in entries {
|
||||
let entry = entry
|
||||
.map_err(|e| Error::Storage(format!("Failed to read directory entry: {}", e)))?;
|
||||
let entries = fs::read_dir(&legacy_dir).map_err(|e| {
|
||||
Error::Storage(format!("Failed to read legacy sessions directory: {e}"))
|
||||
})?;
|
||||
|
||||
let mut json_files = Vec::new();
|
||||
for entry in entries.flatten() {
|
||||
let path = entry.path();
|
||||
if path.extension().and_then(|s| s.to_str()) != Some("json") {
|
||||
continue;
|
||||
if path.extension().and_then(|s| s.to_str()) == Some("json") {
|
||||
json_files.push(path);
|
||||
}
|
||||
}
|
||||
|
||||
// Try to load the conversation to extract metadata
|
||||
match self.load_conversation(&path) {
|
||||
Ok(conv) => {
|
||||
sessions.push(SessionMeta {
|
||||
path: path.clone(),
|
||||
id: conv.id,
|
||||
name: conv.name.clone(),
|
||||
description: conv.description.clone(),
|
||||
message_count: conv.messages.len(),
|
||||
model: conv.model.clone(),
|
||||
created_at: conv.created_at,
|
||||
updated_at: conv.updated_at,
|
||||
});
|
||||
}
|
||||
Err(_) => {
|
||||
// Skip files that can't be parsed
|
||||
continue;
|
||||
if json_files.is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if !io::stdin().is_terminal() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
println!(
|
||||
"Legacy OWLEN session files were found in {}.",
|
||||
legacy_dir.display()
|
||||
);
|
||||
if !prompt_yes_no("Migrate them to the new SQLite storage? (y/N) ")? {
|
||||
println!("Skipping legacy session migration.");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
println!("Migrating legacy sessions...");
|
||||
let mut migrated = 0usize;
|
||||
for path in &json_files {
|
||||
match fs::read_to_string(path) {
|
||||
Ok(content) => match serde_json::from_str::<Conversation>(&content) {
|
||||
Ok(conversation) => {
|
||||
if let Err(err) = self
|
||||
.save_conversation_with_description(
|
||||
&conversation,
|
||||
conversation.name.clone(),
|
||||
conversation.description.clone(),
|
||||
)
|
||||
.await
|
||||
{
|
||||
println!(" • Failed to migrate {}: {}", path.display(), err);
|
||||
} else {
|
||||
migrated += 1;
|
||||
}
|
||||
}
|
||||
Err(err) => {
|
||||
println!(
|
||||
" • Failed to parse conversation {}: {}",
|
||||
path.display(),
|
||||
err
|
||||
);
|
||||
}
|
||||
},
|
||||
Err(err) => {
|
||||
println!(" • Failed to read {}: {}", path.display(), err);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by updated_at, most recent first
|
||||
sessions.sort_by(|a, b| b.updated_at.cmp(&a.updated_at));
|
||||
|
||||
Ok(sessions)
|
||||
}
|
||||
|
||||
/// Delete a saved session
|
||||
pub fn delete_session(&self, path: impl AsRef<Path>) -> Result<()> {
|
||||
fs::remove_file(path.as_ref())
|
||||
.map_err(|e| Error::Storage(format!("Failed to delete session file: {}", e)))
|
||||
}
|
||||
|
||||
/// Get the sessions directory path
|
||||
pub fn sessions_dir(&self) -> &Path {
|
||||
&self.sessions_dir
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for StorageManager {
|
||||
fn default() -> Self {
|
||||
Self::new().expect("Failed to create default storage manager")
|
||||
}
|
||||
}
|
||||
|
||||
/// Sanitize a filename by removing invalid characters
|
||||
fn sanitize_filename(name: &str) -> String {
|
||||
name.chars()
|
||||
.map(|c| {
|
||||
if c.is_alphanumeric() || c == '_' || c == '-' {
|
||||
c
|
||||
} else if c.is_whitespace() {
|
||||
'_'
|
||||
} else {
|
||||
'-'
|
||||
if migrated > 0 {
|
||||
if let Err(err) = archive_legacy_directory(&legacy_dir) {
|
||||
println!(
|
||||
"Warning: migrated sessions but failed to archive legacy directory: {}",
|
||||
err
|
||||
);
|
||||
}
|
||||
})
|
||||
.collect::<String>()
|
||||
.chars()
|
||||
.take(50) // Limit length
|
||||
.collect()
|
||||
}
|
||||
|
||||
println!("Migrated {} legacy sessions.", migrated);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn to_epoch_seconds(time: SystemTime) -> i64 {
|
||||
match time.duration_since(UNIX_EPOCH) {
|
||||
Ok(duration) => duration.as_secs() as i64,
|
||||
Err(_) => 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn from_epoch_seconds(seconds: i64) -> SystemTime {
|
||||
UNIX_EPOCH + Duration::from_secs(seconds.max(0) as u64)
|
||||
}
|
||||
|
||||
fn prompt_yes_no(prompt: &str) -> Result<bool> {
|
||||
print!("{}", prompt);
|
||||
io::stdout()
|
||||
.flush()
|
||||
.map_err(|e| Error::Storage(format!("Failed to flush stdout: {e}")))?;
|
||||
|
||||
let mut input = String::new();
|
||||
io::stdin()
|
||||
.read_line(&mut input)
|
||||
.map_err(|e| Error::Storage(format!("Failed to read input: {e}")))?;
|
||||
let trimmed = input.trim().to_lowercase();
|
||||
Ok(matches!(trimmed.as_str(), "y" | "yes"))
|
||||
}
|
||||
|
||||
fn archive_legacy_directory(legacy_dir: &Path) -> Result<()> {
|
||||
let mut backup_dir = legacy_dir.with_file_name("sessions_legacy_backup");
|
||||
let mut counter = 1;
|
||||
while backup_dir.exists() {
|
||||
backup_dir = legacy_dir.with_file_name(format!("sessions_legacy_backup_{}", counter));
|
||||
counter += 1;
|
||||
}
|
||||
|
||||
fs::rename(legacy_dir, &backup_dir).map_err(|e| {
|
||||
Error::Storage(format!(
|
||||
"Failed to archive legacy sessions directory {}: {}",
|
||||
legacy_dir.display(),
|
||||
e
|
||||
))
|
||||
})?;
|
||||
|
||||
println!("Legacy session files archived to {}", backup_dir.display());
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn create_cipher(master_key: &[u8]) -> Result<Aes256Gcm> {
|
||||
if master_key.len() != 32 {
|
||||
return Err(Error::Storage(
|
||||
"Master key must be 32 bytes for AES-256-GCM".to_string(),
|
||||
));
|
||||
}
|
||||
Aes256Gcm::new_from_slice(master_key).map_err(|_| {
|
||||
Error::Storage("Failed to initialize cipher with provided master key".to_string())
|
||||
})
|
||||
}
|
||||
|
||||
fn generate_nonce() -> Result<[u8; 12]> {
|
||||
let mut nonce = [0u8; 12];
|
||||
SystemRandom::new()
|
||||
.fill(&mut nonce)
|
||||
.map_err(|_| Error::Storage("Failed to generate nonce".to_string()))?;
|
||||
Ok(nonce)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::types::Message;
|
||||
use tempfile::TempDir;
|
||||
use crate::types::{Conversation, Message};
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
fn test_platform_specific_default_path() {
|
||||
let path = StorageManager::default_sessions_dir().unwrap();
|
||||
|
||||
// Verify it contains owlen/sessions
|
||||
assert!(path.to_string_lossy().contains("owlen"));
|
||||
assert!(path.to_string_lossy().contains("sessions"));
|
||||
|
||||
// Platform-specific checks
|
||||
#[cfg(target_os = "linux")]
|
||||
{
|
||||
// Linux should use ~/.local/share/owlen/sessions
|
||||
assert!(path.to_string_lossy().contains(".local/share"));
|
||||
fn sample_conversation() -> Conversation {
|
||||
Conversation {
|
||||
id: Uuid::new_v4(),
|
||||
name: Some("Test conversation".to_string()),
|
||||
description: Some("A sample conversation".to_string()),
|
||||
messages: vec![
|
||||
Message::user("Hello".to_string()),
|
||||
Message::assistant("Hi".to_string()),
|
||||
],
|
||||
model: "test-model".to_string(),
|
||||
created_at: SystemTime::now(),
|
||||
updated_at: SystemTime::now(),
|
||||
}
|
||||
|
||||
#[cfg(target_os = "windows")]
|
||||
{
|
||||
// Windows should use AppData
|
||||
assert!(path.to_string_lossy().contains("AppData"));
|
||||
}
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
// macOS should use ~/Library/Application Support
|
||||
assert!(path
|
||||
.to_string_lossy()
|
||||
.contains("Library/Application Support"));
|
||||
}
|
||||
|
||||
println!("Default sessions directory: {}", path.display());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sanitize_filename() {
|
||||
assert_eq!(sanitize_filename("Hello World"), "Hello_World");
|
||||
assert_eq!(sanitize_filename("test/path\\file"), "test-path-file");
|
||||
assert_eq!(sanitize_filename("file:name?"), "file-name-");
|
||||
}
|
||||
#[tokio::test]
|
||||
async fn test_storage_lifecycle() {
|
||||
let temp_dir = tempdir().expect("failed to create temp dir");
|
||||
let db_path = temp_dir.path().join("owlen.db");
|
||||
let storage = StorageManager::with_database_path(db_path).await.unwrap();
|
||||
|
||||
#[test]
|
||||
fn test_save_and_load_conversation() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let storage = StorageManager::with_directory(temp_dir.path().to_path_buf()).unwrap();
|
||||
let conversation = sample_conversation();
|
||||
storage
|
||||
.save_conversation(&conversation, None)
|
||||
.await
|
||||
.expect("failed to save conversation");
|
||||
|
||||
let mut conv = Conversation::new("test-model".to_string());
|
||||
conv.messages.push(Message::user("Hello".to_string()));
|
||||
conv.messages
|
||||
.push(Message::assistant("Hi there!".to_string()));
|
||||
let sessions = storage.list_sessions().await.unwrap();
|
||||
assert_eq!(sessions.len(), 1);
|
||||
assert_eq!(sessions[0].id, conversation.id);
|
||||
|
||||
// Save conversation
|
||||
let path = storage
|
||||
.save_conversation(&conv, Some("test_session".to_string()))
|
||||
.unwrap();
|
||||
assert!(path.exists());
|
||||
|
||||
// Load conversation
|
||||
let loaded = storage.load_conversation(&path).unwrap();
|
||||
assert_eq!(loaded.id, conv.id);
|
||||
assert_eq!(loaded.model, conv.model);
|
||||
let loaded = storage.load_conversation(conversation.id).await.unwrap();
|
||||
assert_eq!(loaded.messages.len(), 2);
|
||||
assert_eq!(loaded.name, Some("test_session".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_list_sessions() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let storage = StorageManager::with_directory(temp_dir.path().to_path_buf()).unwrap();
|
||||
|
||||
// Create multiple sessions
|
||||
for i in 0..3 {
|
||||
let mut conv = Conversation::new("test-model".to_string());
|
||||
conv.messages.push(Message::user(format!("Message {}", i)));
|
||||
storage
|
||||
.save_conversation(&conv, Some(format!("session_{}", i)))
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// List sessions
|
||||
let sessions = storage.list_sessions().unwrap();
|
||||
assert_eq!(sessions.len(), 3);
|
||||
|
||||
// Check that sessions are sorted by updated_at (most recent first)
|
||||
for i in 0..sessions.len() - 1 {
|
||||
assert!(sessions[i].updated_at >= sessions[i + 1].updated_at);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_delete_session() {
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let storage = StorageManager::with_directory(temp_dir.path().to_path_buf()).unwrap();
|
||||
|
||||
let conv = Conversation::new("test-model".to_string());
|
||||
let path = storage.save_conversation(&conv, None).unwrap();
|
||||
assert!(path.exists());
|
||||
|
||||
storage.delete_session(&path).unwrap();
|
||||
assert!(!path.exists());
|
||||
storage
|
||||
.delete_session(conversation.id)
|
||||
.await
|
||||
.expect("failed to delete conversation");
|
||||
let sessions = storage.list_sessions().await.unwrap();
|
||||
assert!(sessions.is_empty());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,9 @@ pub struct Message {
|
||||
pub metadata: HashMap<String, serde_json::Value>,
|
||||
/// Timestamp when the message was created
|
||||
pub timestamp: std::time::SystemTime,
|
||||
/// Tool calls requested by the assistant
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
/// Role of a message sender
|
||||
@@ -30,6 +33,19 @@ pub enum Role {
|
||||
Assistant,
|
||||
/// System message (prompts, context, etc.)
|
||||
System,
|
||||
/// Tool response message
|
||||
Tool,
|
||||
}
|
||||
|
||||
/// A tool call requested by the assistant
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct ToolCall {
|
||||
/// Unique identifier for this tool call
|
||||
pub id: String,
|
||||
/// Name of the tool to call
|
||||
pub name: String,
|
||||
/// Arguments for the tool (JSON object)
|
||||
pub arguments: serde_json::Value,
|
||||
}
|
||||
|
||||
impl fmt::Display for Role {
|
||||
@@ -38,6 +54,7 @@ impl fmt::Display for Role {
|
||||
Role::User => "user",
|
||||
Role::Assistant => "assistant",
|
||||
Role::System => "system",
|
||||
Role::Tool => "tool",
|
||||
};
|
||||
f.write_str(label)
|
||||
}
|
||||
@@ -72,6 +89,9 @@ pub struct ChatRequest {
|
||||
pub messages: Vec<Message>,
|
||||
/// Optional parameters for the request
|
||||
pub parameters: ChatParameters,
|
||||
/// Optional tools available for the model to use
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<crate::mcp::McpToolDescriptor>>,
|
||||
}
|
||||
|
||||
/// Parameters for chat completion
|
||||
@@ -133,6 +153,9 @@ pub struct ModelInfo {
|
||||
pub context_window: Option<u32>,
|
||||
/// Additional capabilities
|
||||
pub capabilities: Vec<String>,
|
||||
/// Whether this model supports tool/function calling
|
||||
#[serde(default)]
|
||||
pub supports_tools: bool,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
@@ -144,6 +167,7 @@ impl Message {
|
||||
content,
|
||||
metadata: HashMap::new(),
|
||||
timestamp: std::time::SystemTime::now(),
|
||||
tool_calls: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -161,6 +185,24 @@ impl Message {
|
||||
pub fn system(content: String) -> Self {
|
||||
Self::new(Role::System, content)
|
||||
}
|
||||
|
||||
/// Create a tool response message
|
||||
pub fn tool(tool_call_id: String, content: String) -> Self {
|
||||
let mut msg = Self::new(Role::Tool, content);
|
||||
msg.metadata.insert(
|
||||
"tool_call_id".to_string(),
|
||||
serde_json::Value::String(tool_call_id),
|
||||
);
|
||||
msg
|
||||
}
|
||||
|
||||
/// Check if this message has tool calls
|
||||
pub fn has_tool_calls(&self) -> bool {
|
||||
self.tool_calls
|
||||
.as_ref()
|
||||
.map(|tc| !tc.is_empty())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
}
|
||||
|
||||
impl Conversation {
|
||||
|
||||
@@ -357,8 +357,10 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_auto_scroll() {
|
||||
let mut scroll = AutoScroll::default();
|
||||
scroll.content_len = 100;
|
||||
let mut scroll = AutoScroll {
|
||||
content_len: 100,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Test on_viewport with stick_to_bottom
|
||||
scroll.on_viewport(10);
|
||||
|
||||
108
crates/owlen-core/src/validation.rs
Normal file
108
crates/owlen-core/src/validation.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use anyhow::{Context, Result};
|
||||
use jsonschema::{JSONSchema, ValidationError};
|
||||
use serde_json::{json, Value};
|
||||
|
||||
pub struct SchemaValidator {
|
||||
schemas: HashMap<String, JSONSchema>,
|
||||
}
|
||||
|
||||
impl Default for SchemaValidator {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl SchemaValidator {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
schemas: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn register_schema(&mut self, tool_name: &str, schema: Value) -> Result<()> {
|
||||
let compiled = JSONSchema::compile(&schema)
|
||||
.map_err(|e| anyhow::anyhow!("Invalid schema for {}: {}", tool_name, e))?;
|
||||
|
||||
self.schemas.insert(tool_name.to_string(), compiled);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn validate(&self, tool_name: &str, input: &Value) -> Result<()> {
|
||||
let schema = self
|
||||
.schemas
|
||||
.get(tool_name)
|
||||
.with_context(|| format!("No schema registered for tool: {}", tool_name))?;
|
||||
|
||||
if let Err(errors) = schema.validate(input) {
|
||||
let error_messages: Vec<String> = errors.map(format_validation_error).collect();
|
||||
|
||||
return Err(anyhow::anyhow!(
|
||||
"Input validation failed for {}: {}",
|
||||
tool_name,
|
||||
error_messages.join(", ")
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn format_validation_error(error: ValidationError) -> String {
|
||||
format!("Validation error at {}: {}", error.instance_path, error)
|
||||
}
|
||||
|
||||
pub fn get_builtin_schemas() -> HashMap<String, Value> {
|
||||
let mut schemas = HashMap::new();
|
||||
|
||||
schemas.insert(
|
||||
"web_search".to_string(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"maxLength": 500
|
||||
},
|
||||
"max_results": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"maximum": 10,
|
||||
"default": 5
|
||||
}
|
||||
},
|
||||
"required": ["query"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
);
|
||||
|
||||
schemas.insert(
|
||||
"code_exec".to_string(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"language": {
|
||||
"type": "string",
|
||||
"enum": ["python", "javascript", "bash", "rust"]
|
||||
},
|
||||
"code": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"maxLength": 10000
|
||||
},
|
||||
"timeout": {
|
||||
"type": "integer",
|
||||
"minimum": 1,
|
||||
"maximum": 300,
|
||||
"default": 30
|
||||
}
|
||||
},
|
||||
"required": ["language", "code"],
|
||||
"additionalProperties": false
|
||||
}),
|
||||
);
|
||||
|
||||
schemas
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
This crate provides an implementation of the `owlen-core::Provider` trait for the [Ollama](https://ollama.ai) backend.
|
||||
|
||||
It allows Owlen to communicate with a local Ollama instance, sending requests and receiving responses from locally-run large language models.
|
||||
It allows Owlen to communicate with a local Ollama instance, sending requests and receiving responses from locally-run large language models. You can also target [Ollama Cloud](https://docs.ollama.com/cloud) by pointing the provider at `https://ollama.com` (or `https://api.ollama.com`) and providing an API key through your Owlen configuration (or the `OLLAMA_API_KEY` / `OLLAMA_CLOUD_API_KEY` environment variables). The client automatically adds the required Bearer authorization header when a key is supplied, accepts either host without rewriting, and expands inline environment references like `$OLLAMA_API_KEY` if you prefer not to check the secret into your config file. The generated configuration now includes both `providers.ollama` and `providers.ollama-cloud` entries—switch between them by updating `general.default_provider`.
|
||||
|
||||
## Configuration
|
||||
|
||||
|
||||
@@ -5,13 +5,16 @@ use owlen_core::{
|
||||
config::GeneralSettings,
|
||||
model::ModelManager,
|
||||
provider::{ChatStream, Provider, ProviderConfig},
|
||||
types::{ChatParameters, ChatRequest, ChatResponse, Message, ModelInfo, Role, TokenUsage},
|
||||
types::{
|
||||
ChatParameters, ChatRequest, ChatResponse, Message, ModelInfo, Role, TokenUsage, ToolCall,
|
||||
},
|
||||
Result,
|
||||
};
|
||||
use reqwest::Client;
|
||||
use reqwest::{header, Client, Url};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
use std::env;
|
||||
use std::io;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::mpsc;
|
||||
@@ -20,26 +23,195 @@ use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
const DEFAULT_TIMEOUT_SECS: u64 = 120;
|
||||
const DEFAULT_MODEL_CACHE_TTL_SECS: u64 = 60;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum OllamaMode {
|
||||
Local,
|
||||
Cloud,
|
||||
}
|
||||
|
||||
impl OllamaMode {
|
||||
fn from_provider_type(provider_type: &str) -> Self {
|
||||
if provider_type.eq_ignore_ascii_case("ollama-cloud") {
|
||||
Self::Cloud
|
||||
} else {
|
||||
Self::Local
|
||||
}
|
||||
}
|
||||
|
||||
fn default_base_url(self) -> &'static str {
|
||||
match self {
|
||||
Self::Local => "http://localhost:11434",
|
||||
Self::Cloud => "https://ollama.com",
|
||||
}
|
||||
}
|
||||
|
||||
fn default_scheme(self) -> &'static str {
|
||||
match self {
|
||||
Self::Local => "http",
|
||||
Self::Cloud => "https",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_ollama_host(host: &str) -> bool {
|
||||
host.eq_ignore_ascii_case("ollama.com")
|
||||
|| host.eq_ignore_ascii_case("www.ollama.com")
|
||||
|| host.eq_ignore_ascii_case("api.ollama.com")
|
||||
|| host.ends_with(".ollama.com")
|
||||
}
|
||||
|
||||
fn normalize_base_url(
|
||||
input: Option<&str>,
|
||||
mode_hint: OllamaMode,
|
||||
) -> std::result::Result<String, String> {
|
||||
let mut candidate = input
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.map(|value| value.to_string())
|
||||
.unwrap_or_else(|| mode_hint.default_base_url().to_string());
|
||||
|
||||
if !candidate.contains("://") {
|
||||
candidate = format!("{}://{}", mode_hint.default_scheme(), candidate);
|
||||
}
|
||||
|
||||
let mut url =
|
||||
Url::parse(&candidate).map_err(|err| format!("Invalid base_url '{candidate}': {err}"))?;
|
||||
|
||||
let mut is_cloud = matches!(mode_hint, OllamaMode::Cloud);
|
||||
|
||||
if let Some(host) = url.host_str() {
|
||||
if is_ollama_host(host) {
|
||||
is_cloud = true;
|
||||
}
|
||||
}
|
||||
|
||||
if is_cloud {
|
||||
if url.scheme() != "https" {
|
||||
url.set_scheme("https")
|
||||
.map_err(|_| "Ollama Cloud requires an https URL".to_string())?;
|
||||
}
|
||||
|
||||
match url.host_str() {
|
||||
Some(host) => {
|
||||
if host.eq_ignore_ascii_case("www.ollama.com") {
|
||||
url.set_host(Some("ollama.com"))
|
||||
.map_err(|_| "Failed to normalize Ollama Cloud host".to_string())?;
|
||||
}
|
||||
}
|
||||
None => {
|
||||
return Err("Ollama Cloud base_url must include a hostname".to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Remove trailing slash and discard query/fragment segments
|
||||
let current_path = url.path().to_string();
|
||||
let trimmed_path = current_path.trim_end_matches('/');
|
||||
if trimmed_path.is_empty() {
|
||||
url.set_path("");
|
||||
} else {
|
||||
url.set_path(trimmed_path);
|
||||
}
|
||||
|
||||
url.set_query(None);
|
||||
url.set_fragment(None);
|
||||
|
||||
Ok(url.to_string().trim_end_matches('/').to_string())
|
||||
}
|
||||
|
||||
fn build_api_endpoint(base_url: &str, endpoint: &str) -> String {
|
||||
let trimmed_base = base_url.trim_end_matches('/');
|
||||
let trimmed_endpoint = endpoint.trim_start_matches('/');
|
||||
|
||||
if trimmed_base.ends_with("/api") {
|
||||
format!("{trimmed_base}/{trimmed_endpoint}")
|
||||
} else {
|
||||
format!("{trimmed_base}/api/{trimmed_endpoint}")
|
||||
}
|
||||
}
|
||||
|
||||
fn env_var_non_empty(name: &str) -> Option<String> {
|
||||
env::var(name)
|
||||
.ok()
|
||||
.map(|value| value.trim().to_string())
|
||||
.filter(|value| !value.is_empty())
|
||||
}
|
||||
|
||||
fn resolve_api_key(configured: Option<String>) -> Option<String> {
|
||||
let raw = configured?.trim().to_string();
|
||||
if raw.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
if let Some(variable) = raw
|
||||
.strip_prefix("${")
|
||||
.and_then(|value| value.strip_suffix('}'))
|
||||
.or_else(|| raw.strip_prefix('$'))
|
||||
{
|
||||
let var_name = variable.trim();
|
||||
if var_name.is_empty() {
|
||||
return None;
|
||||
}
|
||||
return env_var_non_empty(var_name);
|
||||
}
|
||||
|
||||
Some(raw)
|
||||
}
|
||||
|
||||
fn debug_requests_enabled() -> bool {
|
||||
std::env::var("OWLEN_DEBUG_OLLAMA")
|
||||
.ok()
|
||||
.map(|value| {
|
||||
matches!(
|
||||
value.trim(),
|
||||
"1" | "true" | "TRUE" | "True" | "yes" | "YES" | "Yes"
|
||||
)
|
||||
})
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
fn mask_token(token: &str) -> String {
|
||||
if token.len() <= 8 {
|
||||
return "***".to_string();
|
||||
}
|
||||
|
||||
let head = &token[..4];
|
||||
let tail = &token[token.len() - 4..];
|
||||
format!("{head}***{tail}")
|
||||
}
|
||||
|
||||
fn mask_authorization(value: &str) -> String {
|
||||
if let Some(token) = value.strip_prefix("Bearer ") {
|
||||
format!("Bearer {}", mask_token(token))
|
||||
} else {
|
||||
"***".to_string()
|
||||
}
|
||||
}
|
||||
|
||||
/// Ollama provider implementation with enhanced configuration and caching
|
||||
#[derive(Debug)]
|
||||
pub struct OllamaProvider {
|
||||
client: Client,
|
||||
base_url: String,
|
||||
api_key: Option<String>,
|
||||
model_manager: ModelManager,
|
||||
}
|
||||
|
||||
/// Options for configuring the Ollama provider
|
||||
pub struct OllamaOptions {
|
||||
pub base_url: String,
|
||||
pub request_timeout: Duration,
|
||||
pub model_cache_ttl: Duration,
|
||||
pub(crate) struct OllamaOptions {
|
||||
base_url: String,
|
||||
request_timeout: Duration,
|
||||
model_cache_ttl: Duration,
|
||||
api_key: Option<String>,
|
||||
}
|
||||
|
||||
impl OllamaOptions {
|
||||
pub fn new(base_url: impl Into<String>) -> Self {
|
||||
pub(crate) fn new(base_url: impl Into<String>) -> Self {
|
||||
Self {
|
||||
base_url: base_url.into(),
|
||||
request_timeout: Duration::from_secs(DEFAULT_TIMEOUT_SECS),
|
||||
model_cache_ttl: Duration::from_secs(DEFAULT_MODEL_CACHE_TTL_SECS),
|
||||
api_key: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -54,6 +226,20 @@ impl OllamaOptions {
|
||||
struct OllamaMessage {
|
||||
role: String,
|
||||
content: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<OllamaToolCall>>,
|
||||
}
|
||||
|
||||
/// Ollama tool call format
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct OllamaToolCall {
|
||||
function: OllamaToolCallFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct OllamaToolCallFunction {
|
||||
name: String,
|
||||
arguments: serde_json::Value,
|
||||
}
|
||||
|
||||
/// Ollama chat request format
|
||||
@@ -62,10 +248,27 @@ struct OllamaChatRequest {
|
||||
model: String,
|
||||
messages: Vec<OllamaMessage>,
|
||||
stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<OllamaTool>>,
|
||||
#[serde(flatten)]
|
||||
options: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
/// Ollama tool definition
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct OllamaTool {
|
||||
#[serde(rename = "type")]
|
||||
tool_type: String,
|
||||
function: OllamaToolFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
struct OllamaToolFunction {
|
||||
name: String,
|
||||
description: String,
|
||||
parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
/// Ollama chat response format
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OllamaChatResponse {
|
||||
@@ -107,17 +310,60 @@ struct OllamaModelDetails {
|
||||
impl OllamaProvider {
|
||||
/// Create a new Ollama provider with sensible defaults
|
||||
pub fn new(base_url: impl Into<String>) -> Result<Self> {
|
||||
Self::with_options(OllamaOptions::new(base_url))
|
||||
let mode = OllamaMode::Local;
|
||||
let supplied = base_url.into();
|
||||
let normalized =
|
||||
normalize_base_url(Some(&supplied), mode).map_err(owlen_core::Error::Config)?;
|
||||
|
||||
Self::with_options(OllamaOptions::new(normalized))
|
||||
}
|
||||
|
||||
fn debug_log_request(&self, label: &str, request: &reqwest::Request, body_json: Option<&str>) {
|
||||
if !debug_requests_enabled() {
|
||||
return;
|
||||
}
|
||||
|
||||
eprintln!("--- OWLEN Ollama request ({label}) ---");
|
||||
eprintln!("{} {}", request.method(), request.url());
|
||||
|
||||
match request
|
||||
.headers()
|
||||
.get(header::AUTHORIZATION)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
{
|
||||
Some(value) => eprintln!("Authorization: {}", mask_authorization(value)),
|
||||
None => eprintln!("Authorization: <none>"),
|
||||
}
|
||||
|
||||
if let Some(body) = body_json {
|
||||
eprintln!("Body:\n{body}");
|
||||
}
|
||||
|
||||
eprintln!("---------------------------------------");
|
||||
}
|
||||
|
||||
/// Convert MCP tool descriptors to Ollama tool format
|
||||
fn convert_tools_to_ollama(tools: &[owlen_core::mcp::McpToolDescriptor]) -> Vec<OllamaTool> {
|
||||
tools
|
||||
.iter()
|
||||
.map(|tool| OllamaTool {
|
||||
tool_type: "function".to_string(),
|
||||
function: OllamaToolFunction {
|
||||
name: tool.name.clone(),
|
||||
description: tool.description.clone(),
|
||||
parameters: tool.input_schema.clone(),
|
||||
},
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Create a provider from configuration settings
|
||||
pub fn from_config(config: &ProviderConfig, general: Option<&GeneralSettings>) -> Result<Self> {
|
||||
let mut options = OllamaOptions::new(
|
||||
config
|
||||
.base_url
|
||||
.clone()
|
||||
.unwrap_or_else(|| "http://localhost:11434".to_string()),
|
||||
);
|
||||
let mode = OllamaMode::from_provider_type(&config.provider_type);
|
||||
let normalized_base_url = normalize_base_url(config.base_url.as_deref(), mode)
|
||||
.map_err(owlen_core::Error::Config)?;
|
||||
|
||||
let mut options = OllamaOptions::new(normalized_base_url);
|
||||
|
||||
if let Some(timeout) = config
|
||||
.extra
|
||||
@@ -135,6 +381,10 @@ impl OllamaProvider {
|
||||
options.model_cache_ttl = Duration::from_secs(cache_ttl.max(5));
|
||||
}
|
||||
|
||||
options.api_key = resolve_api_key(config.api_key.clone())
|
||||
.or_else(|| env_var_non_empty("OLLAMA_API_KEY"))
|
||||
.or_else(|| env_var_non_empty("OLLAMA_CLOUD_API_KEY"));
|
||||
|
||||
if let Some(general) = general {
|
||||
options = options.with_general(general);
|
||||
}
|
||||
@@ -143,16 +393,24 @@ impl OllamaProvider {
|
||||
}
|
||||
|
||||
/// Create a provider from explicit options
|
||||
pub fn with_options(options: OllamaOptions) -> Result<Self> {
|
||||
pub(crate) fn with_options(options: OllamaOptions) -> Result<Self> {
|
||||
let OllamaOptions {
|
||||
base_url,
|
||||
request_timeout,
|
||||
model_cache_ttl,
|
||||
api_key,
|
||||
} = options;
|
||||
|
||||
let client = Client::builder()
|
||||
.timeout(options.request_timeout)
|
||||
.timeout(request_timeout)
|
||||
.build()
|
||||
.map_err(|e| owlen_core::Error::Config(format!("Failed to build HTTP client: {e}")))?;
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
base_url: options.base_url.trim_end_matches('/').to_string(),
|
||||
model_manager: ModelManager::new(options.model_cache_ttl),
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key,
|
||||
model_manager: ModelManager::new(model_cache_ttl),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -161,14 +419,42 @@ impl OllamaProvider {
|
||||
&self.model_manager
|
||||
}
|
||||
|
||||
fn api_url(&self, endpoint: &str) -> String {
|
||||
build_api_endpoint(&self.base_url, endpoint)
|
||||
}
|
||||
|
||||
fn apply_auth(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
||||
if let Some(api_key) = &self.api_key {
|
||||
request.bearer_auth(api_key)
|
||||
} else {
|
||||
request
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_message(message: &Message) -> OllamaMessage {
|
||||
let role = match message.role {
|
||||
Role::User => "user".to_string(),
|
||||
Role::Assistant => "assistant".to_string(),
|
||||
Role::System => "system".to_string(),
|
||||
Role::Tool => "tool".to_string(),
|
||||
};
|
||||
|
||||
let tool_calls = message.tool_calls.as_ref().map(|calls| {
|
||||
calls
|
||||
.iter()
|
||||
.map(|tc| OllamaToolCall {
|
||||
function: OllamaToolCallFunction {
|
||||
name: tc.name.clone(),
|
||||
arguments: tc.arguments.clone(),
|
||||
},
|
||||
})
|
||||
.collect()
|
||||
});
|
||||
|
||||
OllamaMessage {
|
||||
role: match message.role {
|
||||
Role::User => "user".to_string(),
|
||||
Role::Assistant => "assistant".to_string(),
|
||||
Role::System => "system".to_string(),
|
||||
},
|
||||
role,
|
||||
content: message.content.clone(),
|
||||
tool_calls,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -177,10 +463,27 @@ impl OllamaProvider {
|
||||
"user" => Role::User,
|
||||
"assistant" => Role::Assistant,
|
||||
"system" => Role::System,
|
||||
"tool" => Role::Tool,
|
||||
_ => Role::Assistant,
|
||||
};
|
||||
|
||||
Message::new(role, message.content.clone())
|
||||
let mut msg = Message::new(role, message.content.clone());
|
||||
|
||||
// Convert tool calls if present
|
||||
if let Some(ollama_tool_calls) = &message.tool_calls {
|
||||
let tool_calls: Vec<ToolCall> = ollama_tool_calls
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(idx, tc)| ToolCall {
|
||||
id: format!("call_{}", idx),
|
||||
name: tc.function.name.clone(),
|
||||
arguments: tc.function.arguments.clone(),
|
||||
})
|
||||
.collect();
|
||||
msg.tool_calls = Some(tool_calls);
|
||||
}
|
||||
|
||||
msg
|
||||
}
|
||||
|
||||
fn build_options(parameters: ChatParameters) -> HashMap<String, Value> {
|
||||
@@ -202,11 +505,10 @@ impl OllamaProvider {
|
||||
}
|
||||
|
||||
async fn fetch_models(&self) -> Result<Vec<ModelInfo>> {
|
||||
let url = format!("{}/api/tags", self.base_url);
|
||||
let url = self.api_url("tags");
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.get(&url)
|
||||
.apply_auth(self.client.get(&url))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| owlen_core::Error::Network(format!("Failed to fetch models: {e}")))?;
|
||||
@@ -229,21 +531,51 @@ impl OllamaProvider {
|
||||
let models = ollama_response
|
||||
.models
|
||||
.into_iter()
|
||||
.map(|model| ModelInfo {
|
||||
id: model.name.clone(),
|
||||
name: model.name.clone(),
|
||||
description: model
|
||||
.details
|
||||
.as_ref()
|
||||
.and_then(|d| d.family.as_ref().map(|f| format!("Ollama {f} model"))),
|
||||
provider: "ollama".to_string(),
|
||||
context_window: None,
|
||||
capabilities: vec!["chat".to_string()],
|
||||
.map(|model| {
|
||||
// Check if model supports tool calling based on known models
|
||||
let supports_tools = Self::check_tool_support(&model.name);
|
||||
|
||||
ModelInfo {
|
||||
id: model.name.clone(),
|
||||
name: model.name.clone(),
|
||||
description: model
|
||||
.details
|
||||
.as_ref()
|
||||
.and_then(|d| d.family.as_ref().map(|f| format!("Ollama {f} model"))),
|
||||
provider: "ollama".to_string(),
|
||||
context_window: None,
|
||||
capabilities: vec!["chat".to_string()],
|
||||
supports_tools,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(models)
|
||||
}
|
||||
|
||||
/// Check if a model supports tool calling based on its name
|
||||
fn check_tool_support(model_name: &str) -> bool {
|
||||
let name_lower = model_name.to_lowercase();
|
||||
|
||||
// Known models with tool calling support
|
||||
let tool_supporting_models = [
|
||||
"qwen",
|
||||
"llama3.1",
|
||||
"llama3.2",
|
||||
"llama3.3",
|
||||
"mistral-nemo",
|
||||
"mistral:7b-instruct",
|
||||
"command-r",
|
||||
"firefunction",
|
||||
"hermes",
|
||||
"nexusraven",
|
||||
"granite-code",
|
||||
];
|
||||
|
||||
tool_supporting_models
|
||||
.iter()
|
||||
.any(|&supported| name_lower.contains(supported))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
@@ -263,25 +595,42 @@ impl Provider for OllamaProvider {
|
||||
model,
|
||||
messages,
|
||||
parameters,
|
||||
tools,
|
||||
} = request;
|
||||
|
||||
let messages: Vec<OllamaMessage> = messages.iter().map(Self::convert_message).collect();
|
||||
|
||||
let options = Self::build_options(parameters);
|
||||
|
||||
let ollama_tools = tools.as_ref().map(|t| Self::convert_tools_to_ollama(t));
|
||||
|
||||
let ollama_request = OllamaChatRequest {
|
||||
model,
|
||||
messages,
|
||||
stream: false,
|
||||
tools: ollama_tools,
|
||||
options,
|
||||
};
|
||||
|
||||
let url = format!("{}/api/chat", self.base_url);
|
||||
let url = self.api_url("chat");
|
||||
let debug_body = if debug_requests_enabled() {
|
||||
serde_json::to_string_pretty(&ollama_request).ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let mut request_builder = self.client.post(&url).json(&ollama_request);
|
||||
request_builder = self.apply_auth(request_builder);
|
||||
|
||||
let request = request_builder.build().map_err(|e| {
|
||||
owlen_core::Error::Network(format!("Failed to build chat request: {e}"))
|
||||
})?;
|
||||
|
||||
self.debug_log_request("chat", &request, debug_body.as_deref());
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.json(&ollama_request)
|
||||
.send()
|
||||
.execute(request)
|
||||
.await
|
||||
.map_err(|e| owlen_core::Error::Network(format!("Chat request failed: {e}")))?;
|
||||
|
||||
@@ -339,28 +688,43 @@ impl Provider for OllamaProvider {
|
||||
model,
|
||||
messages,
|
||||
parameters,
|
||||
tools,
|
||||
} = request;
|
||||
|
||||
let messages: Vec<OllamaMessage> = messages.iter().map(Self::convert_message).collect();
|
||||
|
||||
let options = Self::build_options(parameters);
|
||||
|
||||
let ollama_tools = tools.as_ref().map(|t| Self::convert_tools_to_ollama(t));
|
||||
|
||||
let ollama_request = OllamaChatRequest {
|
||||
model,
|
||||
messages,
|
||||
stream: true,
|
||||
tools: ollama_tools,
|
||||
options,
|
||||
};
|
||||
|
||||
let url = format!("{}/api/chat", self.base_url);
|
||||
let url = self.api_url("chat");
|
||||
let debug_body = if debug_requests_enabled() {
|
||||
serde_json::to_string_pretty(&ollama_request).ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(&url)
|
||||
.json(&ollama_request)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| owlen_core::Error::Network(format!("Streaming request failed: {e}")))?;
|
||||
let mut request_builder = self.client.post(&url).json(&ollama_request);
|
||||
request_builder = self.apply_auth(request_builder);
|
||||
|
||||
let request = request_builder.build().map_err(|e| {
|
||||
owlen_core::Error::Network(format!("Failed to build streaming request: {e}"))
|
||||
})?;
|
||||
|
||||
self.debug_log_request("chat_stream", &request, debug_body.as_deref());
|
||||
|
||||
let response =
|
||||
self.client.execute(request).await.map_err(|e| {
|
||||
owlen_core::Error::Network(format!("Streaming request failed: {e}"))
|
||||
})?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let code = response.status();
|
||||
@@ -462,11 +826,10 @@ impl Provider for OllamaProvider {
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> Result<()> {
|
||||
let url = format!("{}/api/version", self.base_url);
|
||||
let url = self.api_url("version");
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.get(&url)
|
||||
.apply_auth(self.client.get(&url))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| owlen_core::Error::Network(format!("Health check failed: {e}")))?;
|
||||
@@ -528,3 +891,86 @@ async fn parse_error_body(response: reqwest::Response) -> String {
|
||||
Err(_) => "unknown error".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn normalizes_local_base_url_and_infers_scheme() {
|
||||
let normalized =
|
||||
normalize_base_url(Some("localhost:11434"), OllamaMode::Local).expect("valid URL");
|
||||
assert_eq!(normalized, "http://localhost:11434");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normalizes_cloud_base_url_and_host() {
|
||||
let normalized =
|
||||
normalize_base_url(Some("https://ollama.com"), OllamaMode::Cloud).expect("valid URL");
|
||||
assert_eq!(normalized, "https://ollama.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn infers_scheme_for_cloud_hosts() {
|
||||
let normalized =
|
||||
normalize_base_url(Some("ollama.com"), OllamaMode::Cloud).expect("valid URL");
|
||||
assert_eq!(normalized, "https://ollama.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rewrites_www_cloud_host() {
|
||||
let normalized = normalize_base_url(Some("https://www.ollama.com"), OllamaMode::Cloud)
|
||||
.expect("valid URL");
|
||||
assert_eq!(normalized, "https://ollama.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn retains_explicit_api_suffix() {
|
||||
let normalized = normalize_base_url(Some("https://api.ollama.com/api"), OllamaMode::Cloud)
|
||||
.expect("valid URL");
|
||||
assert_eq!(normalized, "https://api.ollama.com/api");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn builds_api_endpoint_without_duplicate_segments() {
|
||||
let base = "http://localhost:11434";
|
||||
assert_eq!(
|
||||
build_api_endpoint(base, "chat"),
|
||||
"http://localhost:11434/api/chat"
|
||||
);
|
||||
|
||||
let base_with_api = "http://localhost:11434/api";
|
||||
assert_eq!(
|
||||
build_api_endpoint(base_with_api, "chat"),
|
||||
"http://localhost:11434/api/chat"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_api_key_prefers_literal_value() {
|
||||
assert_eq!(
|
||||
resolve_api_key(Some("direct-key".into())),
|
||||
Some("direct-key".into())
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_api_key_expands_braced_env_reference() {
|
||||
std::env::set_var("OWLEN_TEST_KEY", "super-secret");
|
||||
assert_eq!(
|
||||
resolve_api_key(Some("${OWLEN_TEST_KEY}".into())),
|
||||
Some("super-secret".into())
|
||||
);
|
||||
std::env::remove_var("OWLEN_TEST_KEY");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_api_key_expands_unbraced_env_reference() {
|
||||
std::env::set_var("OWLEN_TEST_KEY", "another-secret");
|
||||
assert_eq!(
|
||||
resolve_api_key(Some("$OWLEN_TEST_KEY".into())),
|
||||
Some("another-secret".into())
|
||||
);
|
||||
std::env::remove_var("OWLEN_TEST_KEY");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ description = "Terminal User Interface for OWLEN LLM client"
|
||||
|
||||
[dependencies]
|
||||
owlen-core = { path = "../owlen-core" }
|
||||
owlen-ollama = { path = "../owlen-ollama" }
|
||||
|
||||
# TUI framework
|
||||
ratatui = { workspace = true }
|
||||
@@ -26,6 +27,7 @@ futures-util = { workspace = true }
|
||||
# Utilities
|
||||
anyhow = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
serde_json.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
tokio-test = { workspace = true }
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -14,12 +14,14 @@ pub struct CodeApp {
|
||||
}
|
||||
|
||||
impl CodeApp {
|
||||
pub fn new(mut controller: SessionController) -> (Self, mpsc::UnboundedReceiver<SessionEvent>) {
|
||||
pub async fn new(
|
||||
mut controller: SessionController,
|
||||
) -> Result<(Self, mpsc::UnboundedReceiver<SessionEvent>)> {
|
||||
controller
|
||||
.conversation_mut()
|
||||
.push_system_message(DEFAULT_SYSTEM_PROMPT.to_string());
|
||||
let (inner, rx) = ChatApp::new(controller);
|
||||
(Self { inner }, rx)
|
||||
let (inner, rx) = ChatApp::new(controller).await?;
|
||||
Ok((Self { inner }, rx))
|
||||
}
|
||||
|
||||
pub async fn handle_event(&mut self, event: Event) -> Result<AppState> {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
pub use owlen_core::config::{
|
||||
default_config_path, ensure_ollama_config, session_timeout, Config, GeneralSettings,
|
||||
InputSettings, StorageSettings, UiSettings, DEFAULT_CONFIG_PATH,
|
||||
default_config_path, ensure_ollama_config, ensure_provider_config, session_timeout, Config,
|
||||
GeneralSettings, InputSettings, StorageSettings, UiSettings, DEFAULT_CONFIG_PATH,
|
||||
};
|
||||
|
||||
/// Attempt to load configuration from default location
|
||||
|
||||
@@ -3,14 +3,17 @@ use ratatui::style::{Color, Modifier, Style};
|
||||
use ratatui::text::{Line, Span};
|
||||
use ratatui::widgets::{Block, Borders, Clear, List, ListItem, ListState, Paragraph, Wrap};
|
||||
use ratatui::Frame;
|
||||
use serde_json;
|
||||
use textwrap::{wrap, Options};
|
||||
use tui_textarea::TextArea;
|
||||
use unicode_width::UnicodeWidthStr;
|
||||
|
||||
use crate::chat_app::ChatApp;
|
||||
use crate::chat_app::{ChatApp, ModelSelectorItemKind, HELP_TAB_COUNT};
|
||||
use owlen_core::types::Role;
|
||||
use owlen_core::ui::{FocusedPanel, InputMode};
|
||||
|
||||
const PRIVACY_TAB_INDEX: usize = HELP_TAB_COUNT - 1;
|
||||
|
||||
pub fn render_chat(frame: &mut Frame<'_>, app: &mut ChatApp) {
|
||||
// Update thinking content from last message
|
||||
app.update_thinking_from_last_message();
|
||||
@@ -82,14 +85,19 @@ pub fn render_chat(frame: &mut Frame<'_>, app: &mut ChatApp) {
|
||||
|
||||
render_status(frame, layout[idx], app);
|
||||
|
||||
match app.mode() {
|
||||
InputMode::ProviderSelection => render_provider_selector(frame, app),
|
||||
InputMode::ModelSelection => render_model_selector(frame, app),
|
||||
InputMode::Help => render_help(frame, app),
|
||||
InputMode::SessionBrowser => render_session_browser(frame, app),
|
||||
InputMode::ThemeBrowser => render_theme_browser(frame, app),
|
||||
InputMode::Command => render_command_suggestions(frame, app),
|
||||
_ => {}
|
||||
// Render consent dialog with highest priority (always on top)
|
||||
if app.has_pending_consent() {
|
||||
render_consent_dialog(frame, app);
|
||||
} else {
|
||||
match app.mode() {
|
||||
InputMode::ProviderSelection => render_provider_selector(frame, app),
|
||||
InputMode::ModelSelection => render_model_selector(frame, app),
|
||||
InputMode::Help => render_help(frame, app),
|
||||
InputMode::SessionBrowser => render_session_browser(frame, app),
|
||||
InputMode::ThemeBrowser => render_theme_browser(frame, app),
|
||||
InputMode::Command => render_command_suggestions(frame, app),
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -600,12 +608,16 @@ fn render_messages(frame: &mut Frame<'_>, area: Rect, app: &mut ChatApp) {
|
||||
Role::User => ("👤 ", "You: "),
|
||||
Role::Assistant => ("🤖 ", "Assistant: "),
|
||||
Role::System => ("⚙️ ", "System: "),
|
||||
Role::Tool => ("🔧 ", "Tool: "),
|
||||
};
|
||||
|
||||
// Extract content without thinking tags for assistant messages
|
||||
let content_to_display = if matches!(role, Role::Assistant) {
|
||||
let (content_without_think, _) = formatter.extract_thinking(&message.content);
|
||||
content_without_think
|
||||
} else if matches!(role, Role::Tool) {
|
||||
// Format tool results nicely
|
||||
format_tool_output(&message.content)
|
||||
} else {
|
||||
message.content.clone()
|
||||
};
|
||||
@@ -1102,20 +1114,49 @@ fn render_model_selector(frame: &mut Frame<'_>, app: &ChatApp) {
|
||||
frame.render_widget(Clear, area);
|
||||
|
||||
let items: Vec<ListItem> = app
|
||||
.models()
|
||||
.model_selector_items()
|
||||
.iter()
|
||||
.map(|model| {
|
||||
let label = if model.name.is_empty() {
|
||||
model.id.clone()
|
||||
} else {
|
||||
format!("{} — {}", model.id, model.name)
|
||||
};
|
||||
ListItem::new(Span::styled(
|
||||
label,
|
||||
.map(|item| match item.kind() {
|
||||
ModelSelectorItemKind::Header { provider, expanded } => {
|
||||
let marker = if *expanded { "▼" } else { "▶" };
|
||||
let label = format!("{} {}", marker, provider);
|
||||
ListItem::new(Span::styled(
|
||||
label,
|
||||
Style::default()
|
||||
.fg(theme.focused_panel_border)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
))
|
||||
}
|
||||
ModelSelectorItemKind::Model {
|
||||
provider: _,
|
||||
model_index,
|
||||
} => {
|
||||
if let Some(model) = app.model_info_by_index(*model_index) {
|
||||
let tool_indicator = if model.supports_tools { "🔧 " } else { " " };
|
||||
let label = if model.name.is_empty() {
|
||||
format!(" {}{}", tool_indicator, model.id)
|
||||
} else {
|
||||
format!(" {}{} — {}", tool_indicator, model.id, model.name)
|
||||
};
|
||||
ListItem::new(Span::styled(
|
||||
label,
|
||||
Style::default()
|
||||
.fg(theme.user_message_role)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
))
|
||||
} else {
|
||||
ListItem::new(Span::styled(
|
||||
" <model unavailable>",
|
||||
Style::default().fg(theme.error),
|
||||
))
|
||||
}
|
||||
}
|
||||
ModelSelectorItemKind::Empty { provider } => ListItem::new(Span::styled(
|
||||
format!(" (no models configured for {provider})"),
|
||||
Style::default()
|
||||
.fg(theme.user_message_role)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
))
|
||||
.fg(theme.unfocused_panel_border)
|
||||
.add_modifier(Modifier::ITALIC),
|
||||
)),
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -1123,7 +1164,7 @@ fn render_model_selector(frame: &mut Frame<'_>, app: &ChatApp) {
|
||||
.block(
|
||||
Block::default()
|
||||
.title(Span::styled(
|
||||
format!("Select Model ({})", app.selected_provider),
|
||||
"Select Model — 🔧 = Tool Support",
|
||||
Style::default()
|
||||
.fg(theme.focused_panel_border)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
@@ -1139,10 +1180,193 @@ fn render_model_selector(frame: &mut Frame<'_>, app: &ChatApp) {
|
||||
.highlight_symbol("▶ ");
|
||||
|
||||
let mut state = ListState::default();
|
||||
state.select(app.selected_model_index());
|
||||
state.select(app.selected_model_item());
|
||||
frame.render_stateful_widget(list, area, &mut state);
|
||||
}
|
||||
|
||||
fn render_consent_dialog(frame: &mut Frame<'_>, app: &ChatApp) {
|
||||
let theme = app.theme();
|
||||
|
||||
// Get consent dialog state
|
||||
let consent_state = match app.consent_dialog() {
|
||||
Some(state) => state,
|
||||
None => return,
|
||||
};
|
||||
|
||||
// Create centered modal area
|
||||
let area = centered_rect(70, 50, frame.area());
|
||||
frame.render_widget(Clear, area);
|
||||
|
||||
// Build consent dialog content
|
||||
let mut lines = vec![
|
||||
Line::from(vec![
|
||||
Span::styled("🔒 ", Style::default().fg(theme.focused_panel_border)),
|
||||
Span::styled(
|
||||
"Consent Required",
|
||||
Style::default()
|
||||
.fg(theme.focused_panel_border)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
),
|
||||
]),
|
||||
Line::from(""),
|
||||
Line::from(vec![
|
||||
Span::styled("Tool: ", Style::default().add_modifier(Modifier::BOLD)),
|
||||
Span::styled(
|
||||
consent_state.tool_name.clone(),
|
||||
Style::default().fg(theme.user_message_role),
|
||||
),
|
||||
]),
|
||||
Line::from(""),
|
||||
];
|
||||
|
||||
// Add data types if any
|
||||
if !consent_state.data_types.is_empty() {
|
||||
lines.push(Line::from(Span::styled(
|
||||
"Data Access:",
|
||||
Style::default().add_modifier(Modifier::BOLD),
|
||||
)));
|
||||
for data_type in &consent_state.data_types {
|
||||
lines.push(Line::from(vec![
|
||||
Span::raw(" • "),
|
||||
Span::styled(data_type, Style::default().fg(theme.text)),
|
||||
]));
|
||||
}
|
||||
lines.push(Line::from(""));
|
||||
}
|
||||
|
||||
// Add endpoints if any
|
||||
if !consent_state.endpoints.is_empty() {
|
||||
lines.push(Line::from(Span::styled(
|
||||
"Endpoints:",
|
||||
Style::default().add_modifier(Modifier::BOLD),
|
||||
)));
|
||||
for endpoint in &consent_state.endpoints {
|
||||
lines.push(Line::from(vec![
|
||||
Span::raw(" • "),
|
||||
Span::styled(endpoint, Style::default().fg(theme.text)),
|
||||
]));
|
||||
}
|
||||
lines.push(Line::from(""));
|
||||
}
|
||||
|
||||
// Add prompt
|
||||
lines.push(Line::from(""));
|
||||
lines.push(Line::from(vec![Span::styled(
|
||||
"Allow this tool to execute?",
|
||||
Style::default()
|
||||
.fg(theme.focused_panel_border)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
)]));
|
||||
lines.push(Line::from(""));
|
||||
lines.push(Line::from(vec![
|
||||
Span::styled(
|
||||
"[Y] ",
|
||||
Style::default()
|
||||
.fg(Color::Green)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
),
|
||||
Span::raw("Allow "),
|
||||
Span::styled(
|
||||
"[N] ",
|
||||
Style::default().fg(Color::Red).add_modifier(Modifier::BOLD),
|
||||
),
|
||||
Span::raw("Deny "),
|
||||
Span::styled(
|
||||
"[Esc] ",
|
||||
Style::default()
|
||||
.fg(Color::Yellow)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
),
|
||||
Span::raw("Cancel"),
|
||||
]));
|
||||
|
||||
let paragraph = Paragraph::new(lines)
|
||||
.block(
|
||||
Block::default()
|
||||
.title(Span::styled(
|
||||
" Consent Dialog ",
|
||||
Style::default()
|
||||
.fg(theme.focused_panel_border)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
))
|
||||
.borders(Borders::ALL)
|
||||
.border_style(Style::default().fg(theme.focused_panel_border))
|
||||
.style(Style::default().bg(theme.background)),
|
||||
)
|
||||
.alignment(Alignment::Left)
|
||||
.wrap(Wrap { trim: true });
|
||||
|
||||
frame.render_widget(paragraph, area);
|
||||
}
|
||||
|
||||
fn render_privacy_settings(frame: &mut Frame<'_>, area: Rect, app: &ChatApp) {
|
||||
let theme = app.theme();
|
||||
let config = app.config();
|
||||
|
||||
let block = Block::default()
|
||||
.title("Privacy Settings")
|
||||
.borders(Borders::ALL)
|
||||
.border_style(Style::default().fg(theme.unfocused_panel_border))
|
||||
.style(Style::default().bg(theme.background).fg(theme.text));
|
||||
let inner = block.inner(area);
|
||||
frame.render_widget(block, area);
|
||||
|
||||
let remote_search_enabled =
|
||||
config.privacy.enable_remote_search && config.tools.web_search.enabled;
|
||||
let code_exec_enabled = config.tools.code_exec.enabled;
|
||||
let history_days = config.privacy.retain_history_days;
|
||||
let cache_results = config.privacy.cache_web_results;
|
||||
let consent_required = config.privacy.require_consent_per_session;
|
||||
let encryption_enabled = config.privacy.encrypt_local_data;
|
||||
|
||||
let status_line = |label: &str, enabled: bool| {
|
||||
let status_text = if enabled { "Enabled" } else { "Disabled" };
|
||||
let status_style = if enabled {
|
||||
Style::default().fg(theme.selection_fg)
|
||||
} else {
|
||||
Style::default().fg(theme.error)
|
||||
};
|
||||
Line::from(vec![
|
||||
Span::raw(format!(" {label}: ")),
|
||||
Span::styled(status_text, status_style),
|
||||
])
|
||||
};
|
||||
|
||||
let mut lines = Vec::new();
|
||||
lines.push(Line::from(vec![Span::styled(
|
||||
"Privacy Configuration",
|
||||
Style::default().fg(theme.info).add_modifier(Modifier::BOLD),
|
||||
)]));
|
||||
lines.push(Line::raw(""));
|
||||
lines.push(Line::from("Network Access:"));
|
||||
lines.push(status_line("Web Search", remote_search_enabled));
|
||||
lines.push(status_line("Code Execution", code_exec_enabled));
|
||||
lines.push(Line::raw(""));
|
||||
lines.push(Line::from("Data Retention:"));
|
||||
lines.push(Line::from(format!(
|
||||
" History retention: {} day(s)",
|
||||
history_days
|
||||
)));
|
||||
lines.push(Line::from(format!(
|
||||
" Cache web results: {}",
|
||||
if cache_results { "Yes" } else { "No" }
|
||||
)));
|
||||
lines.push(Line::raw(""));
|
||||
lines.push(Line::from("Safeguards:"));
|
||||
lines.push(status_line("Consent required", consent_required));
|
||||
lines.push(status_line("Encrypted storage", encryption_enabled));
|
||||
lines.push(Line::raw(""));
|
||||
lines.push(Line::from("Commands:"));
|
||||
lines.push(Line::from(" :privacy-enable <tool> - Enable tool"));
|
||||
lines.push(Line::from(" :privacy-disable <tool> - Disable tool"));
|
||||
lines.push(Line::from(" :privacy-clear - Clear all data"));
|
||||
|
||||
let paragraph = Paragraph::new(lines)
|
||||
.wrap(Wrap { trim: true })
|
||||
.style(Style::default().bg(theme.background).fg(theme.text));
|
||||
frame.render_widget(paragraph, inner);
|
||||
}
|
||||
|
||||
fn render_help(frame: &mut Frame<'_>, app: &ChatApp) {
|
||||
let theme = app.theme();
|
||||
let area = centered_rect(75, 70, frame.area());
|
||||
@@ -1156,6 +1380,7 @@ fn render_help(frame: &mut Frame<'_>, app: &ChatApp) {
|
||||
"Commands",
|
||||
"Sessions",
|
||||
"Browsers",
|
||||
"Privacy",
|
||||
];
|
||||
|
||||
// Build tab line
|
||||
@@ -1429,6 +1654,7 @@ fn render_help(frame: &mut Frame<'_>, app: &ChatApp) {
|
||||
Line::from(" g / Home → jump to top"),
|
||||
Line::from(" G / End → jump to bottom"),
|
||||
],
|
||||
6 => vec![],
|
||||
|
||||
_ => vec![],
|
||||
};
|
||||
@@ -1454,14 +1680,18 @@ fn render_help(frame: &mut Frame<'_>, app: &ChatApp) {
|
||||
frame.render_widget(tabs_para, layout[0]);
|
||||
|
||||
// Render content
|
||||
let content_block = Block::default()
|
||||
.borders(Borders::LEFT | Borders::RIGHT)
|
||||
.border_style(Style::default().fg(theme.unfocused_panel_border))
|
||||
.style(Style::default().bg(theme.background).fg(theme.text));
|
||||
let content_para = Paragraph::new(help_text)
|
||||
.style(Style::default().bg(theme.background).fg(theme.text))
|
||||
.block(content_block);
|
||||
frame.render_widget(content_para, layout[1]);
|
||||
if tab_index == PRIVACY_TAB_INDEX {
|
||||
render_privacy_settings(frame, layout[1], app);
|
||||
} else {
|
||||
let content_block = Block::default()
|
||||
.borders(Borders::LEFT | Borders::RIGHT)
|
||||
.border_style(Style::default().fg(theme.unfocused_panel_border))
|
||||
.style(Style::default().bg(theme.background).fg(theme.text));
|
||||
let content_para = Paragraph::new(help_text)
|
||||
.style(Style::default().bg(theme.background).fg(theme.text))
|
||||
.block(content_block);
|
||||
frame.render_widget(content_para, layout[1]);
|
||||
}
|
||||
|
||||
// Render navigation hint
|
||||
let nav_hint = Line::from(vec![
|
||||
@@ -1474,7 +1704,7 @@ fn render_help(frame: &mut Frame<'_>, app: &ChatApp) {
|
||||
),
|
||||
Span::raw(":Switch "),
|
||||
Span::styled(
|
||||
"1-6",
|
||||
format!("1-{}", HELP_TAB_COUNT),
|
||||
Style::default()
|
||||
.fg(theme.focused_panel_border)
|
||||
.add_modifier(Modifier::BOLD),
|
||||
@@ -1846,5 +2076,96 @@ fn role_color(role: &Role, theme: &owlen_core::theme::Theme) -> Style {
|
||||
Role::User => Style::default().fg(theme.user_message_role),
|
||||
Role::Assistant => Style::default().fg(theme.assistant_message_role),
|
||||
Role::System => Style::default().fg(theme.info),
|
||||
Role::Tool => Style::default().fg(theme.info),
|
||||
}
|
||||
}
|
||||
|
||||
/// Format tool output JSON into a nice human-readable format
|
||||
fn format_tool_output(content: &str) -> String {
|
||||
// Try to parse as JSON
|
||||
if let Ok(json) = serde_json::from_str::<serde_json::Value>(content) {
|
||||
let mut output = String::new();
|
||||
|
||||
// Extract query if present
|
||||
if let Some(query) = json.get("query").and_then(|v| v.as_str()) {
|
||||
output.push_str(&format!("Query: \"{}\"\n\n", query));
|
||||
}
|
||||
|
||||
// Extract results array
|
||||
if let Some(results) = json.get("results").and_then(|v| v.as_array()) {
|
||||
if results.is_empty() {
|
||||
output.push_str("No results found");
|
||||
return output;
|
||||
}
|
||||
|
||||
for (i, result) in results.iter().enumerate() {
|
||||
// Title
|
||||
if let Some(title) = result.get("title").and_then(|v| v.as_str()) {
|
||||
// Strip HTML tags from title
|
||||
let clean_title = title.replace("<b>", "").replace("</b>", "");
|
||||
output.push_str(&format!("{}. {}\n", i + 1, clean_title));
|
||||
}
|
||||
|
||||
// Source and date (if available)
|
||||
let mut meta = Vec::new();
|
||||
if let Some(source) = result.get("source").and_then(|v| v.as_str()) {
|
||||
meta.push(format!("📰 {}", source));
|
||||
}
|
||||
if let Some(date) = result.get("date").and_then(|v| v.as_str()) {
|
||||
// Simplify date format
|
||||
if let Some(simple_date) = date.split('T').next() {
|
||||
meta.push(format!("📅 {}", simple_date));
|
||||
}
|
||||
}
|
||||
if !meta.is_empty() {
|
||||
output.push_str(&format!(" {}\n", meta.join(" • ")));
|
||||
}
|
||||
|
||||
// Snippet (truncated if too long)
|
||||
if let Some(snippet) = result.get("snippet").and_then(|v| v.as_str()) {
|
||||
if !snippet.is_empty() {
|
||||
// Strip HTML tags
|
||||
let clean_snippet = snippet
|
||||
.replace("<b>", "")
|
||||
.replace("</b>", "")
|
||||
.replace("'", "'")
|
||||
.replace(""", "\"");
|
||||
|
||||
// Truncate if too long
|
||||
let truncated = if clean_snippet.len() > 200 {
|
||||
format!("{}...", &clean_snippet[..197])
|
||||
} else {
|
||||
clean_snippet
|
||||
};
|
||||
output.push_str(&format!(" {}\n", truncated));
|
||||
}
|
||||
}
|
||||
|
||||
// URL (shortened if too long)
|
||||
if let Some(url) = result.get("url").and_then(|v| v.as_str()) {
|
||||
let display_url = if url.len() > 80 {
|
||||
format!("{}...", &url[..77])
|
||||
} else {
|
||||
url.to_string()
|
||||
};
|
||||
output.push_str(&format!(" 🔗 {}\n", display_url));
|
||||
}
|
||||
|
||||
output.push('\n');
|
||||
}
|
||||
|
||||
// Add total count
|
||||
if let Some(total) = json.get("total_found").and_then(|v| v.as_u64()) {
|
||||
output.push_str(&format!("Found {} result(s)", total));
|
||||
}
|
||||
} else if let Some(error) = json.get("error").and_then(|v| v.as_str()) {
|
||||
// Handle error results
|
||||
output.push_str(&format!("❌ Error: {}", error));
|
||||
}
|
||||
|
||||
output
|
||||
} else {
|
||||
// If not JSON, return as-is
|
||||
content.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user