Add App core struct with event-handling and initialization logic for TUI.
This commit is contained in:
542
crates/owlen-tui/src/chat_app.rs
Normal file
542
crates/owlen-tui/src/chat_app.rs
Normal file
@@ -0,0 +1,542 @@
|
||||
use anyhow::Result;
|
||||
use owlen_core::{
|
||||
session::{SessionController, SessionOutcome},
|
||||
types::{ChatParameters, ChatResponse, Conversation, ModelInfo},
|
||||
};
|
||||
use tokio::sync::mpsc;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::config;
|
||||
use crate::events::Event;
|
||||
use std::collections::HashSet;
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AppState {
|
||||
Running,
|
||||
Quit,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum InputMode {
|
||||
Normal,
|
||||
Editing,
|
||||
ProviderSelection,
|
||||
ModelSelection,
|
||||
Help,
|
||||
}
|
||||
|
||||
impl fmt::Display for InputMode {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
let label = match self {
|
||||
InputMode::Normal => "Normal",
|
||||
InputMode::Editing => "Editing",
|
||||
InputMode::ModelSelection => "Model",
|
||||
InputMode::ProviderSelection => "Provider",
|
||||
InputMode::Help => "Help",
|
||||
};
|
||||
f.write_str(label)
|
||||
}
|
||||
}
|
||||
|
||||
/// Messages emitted by asynchronous streaming tasks
|
||||
#[derive(Debug)]
|
||||
pub enum SessionEvent {
|
||||
StreamChunk {
|
||||
message_id: Uuid,
|
||||
response: ChatResponse,
|
||||
},
|
||||
StreamError {
|
||||
message: String,
|
||||
},
|
||||
}
|
||||
|
||||
pub struct ChatApp {
|
||||
controller: SessionController,
|
||||
pub mode: InputMode,
|
||||
pub status: String,
|
||||
pub error: Option<String>,
|
||||
models: Vec<ModelInfo>, // All models fetched
|
||||
pub available_providers: Vec<String>, // Unique providers from models
|
||||
pub selected_provider: String, // The currently selected provider
|
||||
pub selected_provider_index: usize, // Index into the available_providers list
|
||||
pub selected_model: Option<usize>, // Index into the *filtered* models list
|
||||
scroll: usize,
|
||||
session_tx: mpsc::UnboundedSender<SessionEvent>,
|
||||
streaming: HashSet<Uuid>,
|
||||
}
|
||||
|
||||
impl ChatApp {
|
||||
pub fn new(controller: SessionController) -> (Self, mpsc::UnboundedReceiver<SessionEvent>) {
|
||||
let (session_tx, session_rx) = mpsc::unbounded_channel();
|
||||
let app = Self {
|
||||
controller,
|
||||
mode: InputMode::Normal,
|
||||
status: "Ready".to_string(),
|
||||
error: None,
|
||||
models: Vec::new(),
|
||||
available_providers: Vec::new(),
|
||||
selected_provider: "ollama".to_string(), // Default, will be updated in initialize_models
|
||||
selected_provider_index: 0,
|
||||
selected_model: None,
|
||||
scroll: 0,
|
||||
session_tx,
|
||||
streaming: std::collections::HashSet::new(),
|
||||
};
|
||||
|
||||
(app, session_rx)
|
||||
}
|
||||
|
||||
pub fn status_message(&self) -> &str {
|
||||
&self.status
|
||||
}
|
||||
|
||||
pub fn error_message(&self) -> Option<&String> {
|
||||
self.error.as_ref()
|
||||
}
|
||||
|
||||
pub fn mode(&self) -> InputMode {
|
||||
self.mode
|
||||
}
|
||||
|
||||
pub fn conversation(&self) -> &Conversation {
|
||||
self.controller.conversation()
|
||||
}
|
||||
|
||||
pub fn models(&self) -> Vec<&ModelInfo> {
|
||||
self.models.iter()
|
||||
.filter(|m| m.provider == self.selected_provider)
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn selected_model(&self) -> &str {
|
||||
self.controller.selected_model()
|
||||
}
|
||||
|
||||
pub fn config(&self) -> &owlen_core::config::Config {
|
||||
self.controller.config()
|
||||
}
|
||||
|
||||
pub fn selected_model_index(&self) -> Option<usize> {
|
||||
self.selected_model
|
||||
}
|
||||
|
||||
pub fn scroll(&self) -> usize {
|
||||
self.scroll
|
||||
}
|
||||
|
||||
pub fn message_count(&self) -> usize {
|
||||
self.controller.conversation().messages.len()
|
||||
}
|
||||
|
||||
pub fn streaming_count(&self) -> usize {
|
||||
self.streaming.len()
|
||||
}
|
||||
|
||||
pub fn formatter(&self) -> &owlen_core::formatting::MessageFormatter {
|
||||
self.controller.formatter()
|
||||
}
|
||||
|
||||
pub fn input_buffer(&self) -> &owlen_core::input::InputBuffer {
|
||||
self.controller.input_buffer()
|
||||
}
|
||||
|
||||
pub fn input_buffer_mut(&mut self) -> &mut owlen_core::input::InputBuffer {
|
||||
self.controller.input_buffer_mut()
|
||||
}
|
||||
|
||||
pub async fn initialize_models(&mut self) -> Result<()> {
|
||||
let config_model_name = self.controller.config().general.default_model.clone();
|
||||
let config_model_provider = self.controller.config().general.default_provider.clone();
|
||||
|
||||
let all_models = self.controller.models(false).await?;
|
||||
self.models = all_models;
|
||||
|
||||
// Populate available_providers
|
||||
let mut providers = self.models.iter().map(|m| m.provider.clone()).collect::<HashSet<_>>();
|
||||
self.available_providers = providers.into_iter().collect();
|
||||
self.available_providers.sort();
|
||||
|
||||
// Set selected_provider based on config, or default to "ollama" if not found
|
||||
self.selected_provider = self.available_providers.iter()
|
||||
.find(|&p| p == &config_model_provider)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "ollama".to_string());
|
||||
self.selected_provider_index = self.available_providers.iter()
|
||||
.position(|p| p == &self.selected_provider)
|
||||
.unwrap_or(0);
|
||||
|
||||
self.sync_selected_model_index();
|
||||
|
||||
// Ensure the default model is set in the controller and config
|
||||
self.controller.ensure_default_model(&self.models);
|
||||
|
||||
let current_model_name = self.controller.selected_model().to_string();
|
||||
let current_model_provider = self.controller.config().general.default_provider.clone();
|
||||
|
||||
if config_model_name.as_deref() != Some(¤t_model_name) || config_model_provider != current_model_provider {
|
||||
if let Err(err) = config::save_config(self.controller.config()) {
|
||||
self.error = Some(format!("Failed to save config: {err}"));
|
||||
} else {
|
||||
self.error = None;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn handle_event(&mut self, event: Event) -> Result<AppState> {
|
||||
use crossterm::event::{KeyCode, KeyModifiers};
|
||||
|
||||
match event {
|
||||
Event::Tick => {
|
||||
// Future: update streaming timers
|
||||
}
|
||||
Event::Key(key) => match self.mode {
|
||||
InputMode::Normal => match (key.code, key.modifiers) {
|
||||
(KeyCode::Char('q'), KeyModifiers::NONE)
|
||||
| (KeyCode::Char('c'), KeyModifiers::CONTROL) => {
|
||||
return Ok(AppState::Quit);
|
||||
}
|
||||
(KeyCode::Char('m'), KeyModifiers::NONE) => {
|
||||
self.refresh_models().await?;
|
||||
self.mode = InputMode::ProviderSelection;
|
||||
}
|
||||
(KeyCode::Char('n'), KeyModifiers::NONE) => {
|
||||
self.controller.start_new_conversation(None, None);
|
||||
self.status = "Started new conversation".to_string();
|
||||
}
|
||||
(KeyCode::Char('h'), KeyModifiers::NONE) => {
|
||||
self.mode = InputMode::Help;
|
||||
}
|
||||
(KeyCode::Char('c'), KeyModifiers::NONE) => {
|
||||
self.controller.clear();
|
||||
self.status = "Conversation cleared".to_string();
|
||||
}
|
||||
(KeyCode::Enter, KeyModifiers::NONE)
|
||||
| (KeyCode::Char('i'), KeyModifiers::NONE) => {
|
||||
self.mode = InputMode::Editing;
|
||||
}
|
||||
(KeyCode::Up, KeyModifiers::NONE) => {
|
||||
self.scroll = self.scroll.saturating_add(1);
|
||||
}
|
||||
(KeyCode::Down, KeyModifiers::NONE) => {
|
||||
self.scroll = self.scroll.saturating_sub(1);
|
||||
}
|
||||
(KeyCode::Esc, KeyModifiers::NONE) => {
|
||||
self.mode = InputMode::Normal;
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
InputMode::Editing => match key.code {
|
||||
KeyCode::Esc if key.modifiers.is_empty() => {
|
||||
self.mode = InputMode::Normal;
|
||||
self.reset_status();
|
||||
}
|
||||
KeyCode::Enter if key.modifiers.contains(KeyModifiers::SHIFT) => {
|
||||
self.input_buffer_mut().insert_char('\n');
|
||||
}
|
||||
KeyCode::Enter if key.modifiers.is_empty() => {
|
||||
self.try_send_message().await?;
|
||||
self.mode = InputMode::Normal;
|
||||
}
|
||||
KeyCode::Enter => {
|
||||
self.input_buffer_mut().insert_char('\n');
|
||||
}
|
||||
KeyCode::Char('j') if key.modifiers.contains(KeyModifiers::CONTROL) => {
|
||||
self.input_buffer_mut().insert_char('\n');
|
||||
}
|
||||
KeyCode::Backspace => {
|
||||
self.input_buffer_mut().backspace();
|
||||
}
|
||||
KeyCode::Delete => {
|
||||
self.input_buffer_mut().delete();
|
||||
}
|
||||
KeyCode::Left => {
|
||||
self.input_buffer_mut().move_left();
|
||||
}
|
||||
KeyCode::Right => {
|
||||
self.input_buffer_mut().move_right();
|
||||
}
|
||||
KeyCode::Home => {
|
||||
self.input_buffer_mut().move_home();
|
||||
}
|
||||
KeyCode::End => {
|
||||
self.input_buffer_mut().move_end();
|
||||
}
|
||||
KeyCode::Up => {
|
||||
self.input_buffer_mut().history_previous();
|
||||
}
|
||||
KeyCode::Down => {
|
||||
self.input_buffer_mut().history_next();
|
||||
}
|
||||
KeyCode::Char(c)
|
||||
if key.modifiers.is_empty()
|
||||
|| key.modifiers.contains(KeyModifiers::SHIFT) =>
|
||||
{
|
||||
self.input_buffer_mut().insert_char(c);
|
||||
}
|
||||
KeyCode::Tab => {
|
||||
self.input_buffer_mut().insert_tab();
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
InputMode::ProviderSelection => match key.code {
|
||||
KeyCode::Esc => {
|
||||
self.mode = InputMode::Normal;
|
||||
}
|
||||
KeyCode::Enter => {
|
||||
if let Some(provider) = self.available_providers.get(self.selected_provider_index) {
|
||||
self.selected_provider = provider.clone();
|
||||
self.sync_selected_model_index(); // Update model selection based on new provider
|
||||
self.mode = InputMode::ModelSelection;
|
||||
}
|
||||
}
|
||||
KeyCode::Up => {
|
||||
if self.selected_provider_index > 0 {
|
||||
self.selected_provider_index -= 1;
|
||||
}
|
||||
}
|
||||
KeyCode::Down => {
|
||||
if self.selected_provider_index + 1 < self.available_providers.len() {
|
||||
self.selected_provider_index += 1;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
InputMode::ModelSelection => match key.code {
|
||||
KeyCode::Esc => {
|
||||
self.mode = InputMode::Normal;
|
||||
}
|
||||
KeyCode::Enter => {
|
||||
if let Some(selected_model_idx) = self.selected_model {
|
||||
let filtered_models = self.models();
|
||||
if let Some(model) = filtered_models.get(selected_model_idx) {
|
||||
let model_id = model.id.clone();
|
||||
let model_name = model.name.clone();
|
||||
|
||||
self.controller.set_model(model_id.clone());
|
||||
self.status = format!("Using model: {}", model_name);
|
||||
// Save the selected provider and model to config
|
||||
self.controller.config_mut().general.default_model = Some(model_id.clone());
|
||||
self.controller.config_mut().general.default_provider = self.selected_provider.clone();
|
||||
match config::save_config(self.controller.config()) {
|
||||
Ok(_) => self.error = None,
|
||||
Err(err) => {
|
||||
self.error = Some(format!("Failed to save config: {}", err));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
self.mode = InputMode::Normal;
|
||||
}
|
||||
KeyCode::Up => {
|
||||
if let Some(selected_model_idx) = self.selected_model {
|
||||
if selected_model_idx > 0 {
|
||||
self.selected_model = Some(selected_model_idx - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
KeyCode::Down => {
|
||||
if let Some(selected_model_idx) = self.selected_model {
|
||||
if selected_model_idx + 1 < self.models().len() {
|
||||
self.selected_model = Some(selected_model_idx + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
InputMode::Help => match key.code {
|
||||
KeyCode::Esc | KeyCode::Enter => {
|
||||
self.mode = InputMode::Normal;
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
},
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(AppState::Running)
|
||||
}
|
||||
|
||||
pub fn handle_session_event(&mut self, event: SessionEvent) -> Result<()> {
|
||||
match event {
|
||||
SessionEvent::StreamChunk {
|
||||
message_id,
|
||||
response,
|
||||
} => {
|
||||
self.controller.apply_stream_chunk(message_id, &response)?;
|
||||
if response.is_final {
|
||||
self.streaming.remove(&message_id);
|
||||
self.status = "Response complete".to_string();
|
||||
}
|
||||
}
|
||||
SessionEvent::StreamError { message } => {
|
||||
self.error = Some(message);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn reset_status(&mut self) {
|
||||
self.status = "Ready".to_string();
|
||||
self.error = None;
|
||||
}
|
||||
|
||||
async fn refresh_models(&mut self) -> Result<()> {
|
||||
let config_model_name = self.controller.config().general.default_model.clone();
|
||||
let config_model_provider = self.controller.config().general.default_provider.clone();
|
||||
|
||||
let all_models = self.controller.models(true).await?;
|
||||
if all_models.is_empty() {
|
||||
self.error = Some("No models available".to_string());
|
||||
} else {
|
||||
self.models = all_models;
|
||||
|
||||
// Populate available_providers
|
||||
let mut providers = self.models.iter().map(|m| m.provider.clone()).collect::<HashSet<_>>();
|
||||
self.available_providers = providers.into_iter().collect();
|
||||
self.available_providers.sort();
|
||||
|
||||
// Set selected_provider based on config, or default to "ollama" if not found
|
||||
self.selected_provider = self.available_providers.iter()
|
||||
.find(|&p| p == &config_model_provider)
|
||||
.cloned()
|
||||
.unwrap_or_else(|| "ollama".to_string());
|
||||
self.selected_provider_index = self.available_providers.iter()
|
||||
.position(|p| p == &self.selected_provider)
|
||||
.unwrap_or(0);
|
||||
|
||||
self.controller.ensure_default_model(&self.models);
|
||||
self.sync_selected_model_index();
|
||||
|
||||
let current_model_name = self.controller.selected_model().to_string();
|
||||
let current_model_provider = self.controller.config().general.default_provider.clone();
|
||||
|
||||
if config_model_name.as_deref() != Some(¤t_model_name) || config_model_provider != current_model_provider {
|
||||
if let Err(err) = config::save_config(self.controller.config()) {
|
||||
self.error = Some(format!("Failed to save config: {err}"));
|
||||
} else {
|
||||
self.error = None;
|
||||
}
|
||||
}
|
||||
self.status = format!("Loaded {} models", self.models.len());
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn try_send_message(&mut self) -> Result<()> {
|
||||
let content = self.controller.input_buffer().text().trim().to_string();
|
||||
if content.is_empty() {
|
||||
self.error = Some("Cannot send empty message".to_string());
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let message = self.controller.input_buffer_mut().commit_to_history();
|
||||
let mut parameters = ChatParameters::default();
|
||||
parameters.stream = self.controller.config().general.enable_streaming;
|
||||
|
||||
match self.controller.send_message(message, parameters).await {
|
||||
Ok(SessionOutcome::Complete(_response)) => {
|
||||
self.status = "Response received".to_string();
|
||||
self.error = None;
|
||||
Ok(())
|
||||
}
|
||||
Ok(SessionOutcome::Streaming {
|
||||
response_id,
|
||||
stream,
|
||||
}) => {
|
||||
self.spawn_stream(response_id, stream);
|
||||
match self
|
||||
.controller
|
||||
.mark_stream_placeholder(response_id, "Loading...")
|
||||
{
|
||||
Ok(_) => self.error = None,
|
||||
Err(err) => {
|
||||
self.error = Some(format!("Could not set loading placeholder: {}", err));
|
||||
}
|
||||
}
|
||||
self.status = "Waiting for response...".to_string();
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => {
|
||||
let message = err.to_string();
|
||||
if message.to_lowercase().contains("not found") {
|
||||
self.error = Some(
|
||||
"Model not available. Press 'm' to pick another installed model."
|
||||
.to_string(),
|
||||
);
|
||||
self.status = "Model unavailable".to_string();
|
||||
let _ = self.refresh_models().await;
|
||||
self.mode = InputMode::ProviderSelection;
|
||||
} else {
|
||||
self.error = Some(message);
|
||||
self.status = "Send failed".to_string();
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn sync_selected_model_index(&mut self) {
|
||||
let current_model_id = self.controller.selected_model().to_string();
|
||||
let filtered_models: Vec<&ModelInfo> = self.models.iter()
|
||||
.filter(|m| m.provider == self.selected_provider)
|
||||
.collect();
|
||||
|
||||
if filtered_models.is_empty() {
|
||||
self.selected_model = None;
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(idx) = filtered_models
|
||||
.iter()
|
||||
.position(|m| m.id == current_model_id)
|
||||
{
|
||||
self.selected_model = Some(idx);
|
||||
} else {
|
||||
// If the current model is not in the filtered list, select the first one
|
||||
self.selected_model = Some(0);
|
||||
if let Some(model) = filtered_models.get(0) {
|
||||
self.controller.set_model(model.id.clone());
|
||||
// Also update the config with the new model and provider
|
||||
self.controller.config_mut().general.default_model = Some(model.id.clone());
|
||||
self.controller.config_mut().general.default_provider = self.selected_provider.clone();
|
||||
if let Err(err) = config::save_config(self.controller.config()) {
|
||||
self.error = Some(format!("Failed to save config: {err}"));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_stream(&mut self, message_id: Uuid, mut stream: owlen_core::provider::ChatStream) {
|
||||
let sender = self.session_tx.clone();
|
||||
self.streaming.insert(message_id);
|
||||
|
||||
tokio::spawn(async move {
|
||||
use futures_util::StreamExt;
|
||||
|
||||
while let Some(item) = stream.next().await {
|
||||
match item {
|
||||
Ok(response) => {
|
||||
if sender
|
||||
.send(SessionEvent::StreamChunk {
|
||||
message_id,
|
||||
response,
|
||||
})
|
||||
.is_err()
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
let _ = sender.send(SessionEvent::StreamError {
|
||||
message: e.to_string(),
|
||||
});
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user