//! Engine task: receives messages from the UI, calls the LLM provider, and
//! streams the response back while keeping the shared conversation state updated.

use crate::messages::{Message, UserAction, AgentResponse};
use crate::state::AppState;
use tokio::sync::{mpsc, Mutex};
use std::sync::Arc;
use llm_core::{LlmProvider, ChatMessage, ChatOptions};
use futures::StreamExt;

/// The main background task that handles logic, API calls, and state updates.
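/// It listens on `rx` for `Message::UserAction(UserAction::Input(..))`, streams the provider
/// reply back over `tx_ui` as `AgentResponse::Token` chunks (errors become `AgentResponse::Error`),
/// records both sides of the exchange in the shared `AppState`, and ends each turn with
/// `AgentResponse::Complete`.
///
/// A minimal spawn sketch (channel capacities are placeholders; assumes a concrete
/// `LlmProvider` implementation and an `AppState` have already been constructed):
///
/// ```ignore
/// let (tx_engine, rx_engine) = mpsc::channel(32);
/// let (tx_ui, mut rx_ui) = mpsc::channel(32);
/// tokio::spawn(run_engine_loop(rx_engine, tx_ui, provider, state));
/// ```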
pub async fn run_engine_loop(
    mut rx: mpsc::Receiver<Message>,
    tx_ui: mpsc::Sender<Message>,
    client: Arc<dyn LlmProvider>,
    state: Arc<Mutex<AppState>>,
) {
    while let Some(msg) = rx.recv().await {
        match msg {
            Message::UserAction(UserAction::Input(text)) => {
                // Update history with the user message. The guard is scoped so the
                // state lock is released before the provider call below.
                let messages = {
                    let mut guard = state.lock().await;
                    guard.add_message(ChatMessage::user(text.clone()));
                    guard.messages.clone()
                };

                // Use default options for now
                let options = ChatOptions::default();

                match client.chat_stream(&messages, &options, None).await {
                    Ok(mut stream) => {
                        let mut full_response = String::new();

                        while let Some(result) = stream.next().await {
                            match result {
                                Ok(chunk) => {
                                    if let Some(content) = chunk.content {
                                        full_response.push_str(&content);
                                        if let Err(e) = tx_ui.send(Message::AgentResponse(AgentResponse::Token(content))).await {
                                            eprintln!("Failed to send token to UI: {}", e);
                                            break;
                                        }
                                    }
                                }
                                Err(e) => {
                                    // Report the error to the UI but keep draining the stream.
                                    let _ = tx_ui.send(Message::AgentResponse(AgentResponse::Error(e.to_string()))).await;
                                }
                            }
                        }

                        // Add the assistant response to history
                        {
                            let mut guard = state.lock().await;
                            guard.add_message(ChatMessage::assistant(full_response));
                        }

                        let _ = tx_ui.send(Message::AgentResponse(AgentResponse::Complete)).await;
                    }
                    Err(e) => {
                        let _ = tx_ui.send(Message::AgentResponse(AgentResponse::Error(e.to_string()))).await;
                    }
                }
            }
            // Other message variants are not handled by the engine.
            _ => {}
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::messages::{Message, UserAction, AgentResponse};
    use llm_core::{LlmProvider, LlmError, ChatMessage, ChatOptions, Tool, ChunkStream, StreamChunk};
    use async_trait::async_trait;
    use futures::stream;

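    /// Test double for `LlmProvider` that yields a fixed two-chunk stream: "Hello", then " World".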
    struct MockProvider;

    #[async_trait]
    impl LlmProvider for MockProvider {
        fn name(&self) -> &str { "mock" }
        fn model(&self) -> &str { "mock-model" }

        async fn chat_stream(
            &self,
            _messages: &[ChatMessage],
            _options: &ChatOptions,
            _tools: Option<&[Tool]>,
        ) -> Result<ChunkStream, LlmError> {
            let chunks = vec![
                Ok(StreamChunk { content: Some("Hello".to_string()), tool_calls: None, done: false, usage: None }),
                Ok(StreamChunk { content: Some(" World".to_string()), tool_calls: None, done: true, usage: None }),
            ];
            Ok(Box::pin(stream::iter(chunks)))
        }
    }

    #[tokio::test]
    async fn test_engine_streaming() {
        let (tx_in, rx_in) = mpsc::channel(1);
        let (tx_out, mut rx_out) = mpsc::channel(10);

        let client = Arc::new(MockProvider);
        let state = Arc::new(Mutex::new(AppState::new()));

        // Spawn the engine loop
        tokio::spawn(async move {
            run_engine_loop(rx_in, tx_out, client, state).await;
        });

        // Send a message
        tx_in.send(Message::UserAction(UserAction::Input("Hi".to_string()))).await.unwrap();

        // Verify streaming responses
        if let Some(Message::AgentResponse(AgentResponse::Token(s))) = rx_out.recv().await {
            assert_eq!(s, "Hello");
        } else {
            panic!("Expected Token(Hello)");
        }

        if let Some(Message::AgentResponse(AgentResponse::Token(s))) = rx_out.recv().await {
            assert_eq!(s, " World");
        } else {
            panic!("Expected Token( World)");
        }

        if let Some(Message::AgentResponse(AgentResponse::Complete)) = rx_out.recv().await {
            // OK
        } else {
            panic!("Expected Complete");
        }
    }

    #[tokio::test]
    async fn test_engine_state_update() {
        let (tx_in, rx_in) = mpsc::channel(1);
        let (tx_out, mut rx_out) = mpsc::channel(10);

        let client = Arc::new(MockProvider);
        let state = Arc::new(Mutex::new(AppState::new()));
        let state_clone = state.clone();

        // Spawn the engine loop
        tokio::spawn(async move {
            run_engine_loop(rx_in, tx_out, client, state_clone).await;
        });

        // Send a message
        tx_in.send(Message::UserAction(UserAction::Input("Hi".to_string()))).await.unwrap();

        // Wait for completion
        while let Some(msg) = rx_out.recv().await {
            if let Message::AgentResponse(AgentResponse::Complete) = msg {
                break;
            }
        }

        // Verify state
        let guard = state.lock().await;
        assert_eq!(guard.messages.len(), 2); // User + Assistant
        match &guard.messages[0].role {
            llm_core::Role::User => {},
            _ => panic!("First message should be User"),
        }
        match &guard.messages[1].role {
            llm_core::Role::Assistant => {},
            _ => panic!("Second message should be Assistant"),
        }
    }
}