Files
owlen/crates/core/agent/src/session.rs

338 lines
10 KiB
Rust

//! Session state and history management.
//!
//! This module provides tools for tracking conversation history, capturing
//! usage statistics, and managing session checkpoints for persistence and rewind.
use color_eyre::eyre::{Result, eyre};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::time::{Duration, SystemTime};
/// Statistics for a single chat session.
///
/// A fresh value (see `SessionStats::new`) has every counter at zero; the
/// counters are advanced by `record_message` and `record_tool_call`. The type
/// derives `Serialize`/`Deserialize`, and a copy is stored in each
/// `Checkpoint`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionStats {
/// The time when the session started; `SessionStats::new` sets it to the
/// system clock reading at construction.
pub start_time: SystemTime,
/// Total number of messages exchanged; `record_message` adds one per call.
pub total_messages: usize,
/// Total number of tools executed by the agent; `record_tool_call` adds one
/// per call.
pub total_tool_calls: usize,
/// Total time spent in the session, accumulated from the per-message
/// durations passed to `record_message` (it is not derived from
/// `start_time`).
pub total_duration: Duration,
/// Rough estimate of the total tokens used, summed from the `tokens`
/// argument of `record_message`.
pub estimated_tokens: usize,
}
impl SessionStats {
    /// Builds a fresh `SessionStats`: every counter is zero and `start_time`
    /// is the system clock reading at the moment of the call.
    pub fn new() -> Self {
        let start_time = SystemTime::now();
        Self {
            start_time,
            total_messages: 0,
            total_tool_calls: 0,
            total_duration: Duration::default(),
            estimated_tokens: 0,
        }
    }

    /// Accounts for one exchanged message.
    ///
    /// Bumps the message count by one and adds `tokens` and `duration` to the
    /// running token estimate and total duration respectively.
    pub fn record_message(&mut self, tokens: usize, duration: Duration) {
        let Self {
            total_messages,
            estimated_tokens,
            total_duration,
            ..
        } = self;
        *total_messages += 1;
        *estimated_tokens += tokens;
        *total_duration += duration;
    }

    /// Accounts for one tool invocation.
    pub fn record_tool_call(&mut self) {
        let calls = &mut self.total_tool_calls;
        *calls += 1;
    }

    /// Renders a duration as a short human-readable string.
    ///
    /// Sub-second precision is discarded. The output has one of three shapes:
    /// `"42s"` below one minute, `"3m 7s"` below one hour, and `"2h 15m"`
    /// otherwise (seconds are dropped once hours are shown).
    pub fn format_duration(d: Duration) -> String {
        let total = d.as_secs();
        let hours = total / 3600;
        let minutes = (total % 3600) / 60;
        let seconds = total % 60;
        match total {
            0..=59 => format!("{}s", seconds),
            60..=3599 => format!("{}m {}s", minutes, seconds),
            _ => format!("{}h {}m", hours, minutes),
        }
    }
}
impl Default for SessionStats {
fn default() -> Self {
Self::new()
}
}
/// In-memory history of the current session.
///
/// Prompts, responses, and tool calls are kept in three independent vectors,
/// each in insertion order; the relative ordering *between* the vectors is not
/// recorded. This type derives only `Debug` and `Clone`; `Checkpoint::new`
/// clones the three vectors out of it in order to persist them.
#[derive(Debug, Clone)]
pub struct SessionHistory {
/// List of prompts provided by the user, oldest first.
pub user_prompts: Vec<String>,
/// List of responses generated by the assistant, oldest first.
pub assistant_responses: Vec<String>,
/// Chronological log of all tool calls made.
pub tool_calls: Vec<ToolCallRecord>,
}
/// Record of a single tool execution.
///
/// Instances are stored in `SessionHistory::tool_calls` and are serialized as
/// part of a `Checkpoint`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallRecord {
/// Name of the tool that was called.
pub tool_name: String,
/// JSON-encoded arguments provided to the tool. Held as an opaque string;
/// nothing in this module parses or validates it.
pub arguments: String,
/// Output produced by the tool.
pub result: String,
/// Whether the tool execution was successful.
pub success: bool,
}
impl SessionHistory {
    /// Builds a history with no prompts, responses, or tool calls recorded.
    pub fn new() -> Self {
        Self {
            user_prompts: vec![],
            assistant_responses: vec![],
            tool_calls: vec![],
        }
    }

    /// Records `message` as the most recent user prompt.
    pub fn add_user_message(&mut self, message: String) {
        let prompts = &mut self.user_prompts;
        prompts.push(message);
    }

    /// Records `message` as the most recent assistant response.
    pub fn add_assistant_message(&mut self, message: String) {
        let responses = &mut self.assistant_responses;
        responses.push(message);
    }

    /// Records `record` as the most recent tool call.
    pub fn add_tool_call(&mut self, record: ToolCallRecord) {
        let calls = &mut self.tool_calls;
        calls.push(record);
    }

    /// Empties all three history vectors.
    pub fn clear(&mut self) {
        // Exhaustive destructuring: adding a field to the struct without
        // clearing it here becomes a compile error.
        let Self {
            user_prompts,
            assistant_responses,
            tool_calls,
        } = self;
        user_prompts.clear();
        assistant_responses.clear();
        tool_calls.clear();
    }
}
impl Default for SessionHistory {
fn default() -> Self {
Self::new()
}
}
/// Represents a file modification with before/after content.
///
/// Both versions are held as complete strings rather than as a patch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileDiff {
/// Path to the file, stored exactly as supplied to `FileDiff::new`.
/// NOTE(review): nothing in this module canonicalizes it, so it is absolute
/// only if callers pass absolute paths — confirm at the call sites.
pub path: PathBuf,
/// Content of the file before modification.
pub before: String,
/// Content of the file after modification.
pub after: String,
/// When this record was created: `FileDiff::new` stamps it with the current
/// system time, which may be later than the actual modification.
pub timestamp: SystemTime,
}
impl FileDiff {
    /// Builds a `FileDiff` for `path` from its `before` and `after` contents.
    ///
    /// The `timestamp` field is set to the system clock reading at the time
    /// of this call.
    pub fn new(path: PathBuf, before: String, after: String) -> Self {
        let timestamp = SystemTime::now();
        Self {
            path,
            before,
            after,
            timestamp,
        }
    }
}
/// A checkpoint captures the full state of a session at a point in time.
///
/// It is persisted as a single JSON document (see `Checkpoint::save` and
/// `Checkpoint::load`); the history fields mirror those of `SessionHistory`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
/// Unique identifier for the checkpoint. It is also used as the file stem of
/// the JSON file on disk (`<id>.json`).
pub id: String,
/// When the checkpoint was created (stamped by `Checkpoint::new`).
pub timestamp: SystemTime,
/// Session statistics at the time of checkpoint.
pub stats: SessionStats,
/// History of user prompts.
pub user_prompts: Vec<String>,
/// History of assistant responses.
pub assistant_responses: Vec<String>,
/// History of tool calls.
pub tool_calls: Vec<ToolCallRecord>,
/// File modifications captured with this checkpoint. When the checkpoint is
/// produced by `CheckpointManager::save_checkpoint`, these are the changes
/// to files snapshotted since the manager's snapshots were last cleared;
/// `CheckpointManager::rewind_to` writes each diff's `before` content back.
pub file_diffs: Vec<FileDiff>,
}
impl Checkpoint {
    /// Creates a new checkpoint from the current session state.
    ///
    /// The prompt, response, and tool-call vectors are cloned out of
    /// `history`; `stats` and `file_diffs` are taken by value. The checkpoint
    /// `timestamp` is the system clock reading at the time of this call.
    pub fn new(
        id: String,
        stats: SessionStats,
        history: &SessionHistory,
        file_diffs: Vec<FileDiff>,
    ) -> Self {
        Self {
            id,
            timestamp: SystemTime::now(),
            stats,
            user_prompts: history.user_prompts.clone(),
            assistant_responses: history.assistant_responses.clone(),
            tool_calls: history.tool_calls.clone(),
            file_diffs,
        }
    }

    /// Saves the checkpoint as pretty-printed JSON at
    /// `<checkpoint_dir>/<id>.json`, creating the directory if needed.
    ///
    /// The JSON is first written to a sibling `<id>.json.tmp` file and then
    /// renamed into place, so an interrupted write does not leave a truncated
    /// `.json` file that `list` would report but `load` could not parse.
    ///
    /// # Errors
    ///
    /// Returns an error if the directory cannot be created, serialization
    /// fails, or the file cannot be written or renamed.
    pub fn save(&self, checkpoint_dir: &Path) -> Result<()> {
        fs::create_dir_all(checkpoint_dir)?;
        let content = serde_json::to_string_pretty(self)?;

        let final_path = checkpoint_dir.join(format!("{}.json", self.id));
        // The `.tmp` suffix keeps the in-progress file out of `list`, which
        // only reports entries whose extension is exactly `json`.
        let temp_path = checkpoint_dir.join(format!("{}.json.tmp", self.id));

        let written =
            fs::write(&temp_path, content).and_then(|()| fs::rename(&temp_path, &final_path));
        if let Err(err) = written {
            // Best-effort cleanup; the original failure is what gets reported.
            let _ = fs::remove_file(&temp_path);
            return Err(err.into());
        }
        Ok(())
    }

    /// Loads the checkpoint stored at `<checkpoint_dir>/<id>.json`.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be read or its content is not a
    /// valid serialized `Checkpoint`.
    pub fn load(checkpoint_dir: &Path, id: &str) -> Result<Self> {
        let path = checkpoint_dir.join(format!("{}.json", id));
        let content = fs::read_to_string(&path)
            .map_err(|e| eyre!("Failed to read checkpoint: {}", e))?;
        serde_json::from_str(&content).map_err(|e| eyre!("Failed to parse checkpoint: {}", e))
    }

    /// Lists the IDs (file stems) of all `*.json` entries in `checkpoint_dir`,
    /// sorted lexicographically.
    ///
    /// A directory that does not exist yields an empty list. Entries whose
    /// name is not valid UTF-8 are skipped.
    ///
    /// # Errors
    ///
    /// Returns an error if the directory exists but cannot be read (for
    /// example because of permissions), or if iterating its entries fails.
    pub fn list(checkpoint_dir: &Path) -> Result<Vec<String>> {
        // Ask `read_dir` directly instead of probing with `exists()` first:
        // that avoids a check-then-use race and stops other I/O errors from
        // being reported as "no checkpoints".
        let entries = match fs::read_dir(checkpoint_dir) {
            Ok(entries) => entries,
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => return Ok(Vec::new()),
            Err(err) => return Err(err.into()),
        };

        let mut checkpoints = Vec::new();
        for entry in entries {
            let path = entry?.path();
            if path.extension().and_then(|ext| ext.to_str()) != Some("json") {
                continue;
            }
            if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
                checkpoints.push(stem.to_owned());
            }
        }

        // Lexicographic order. NOTE(review): this is chronological only if the
        // caller-chosen IDs embed a sortable timestamp — the ID format is not
        // defined in this module; confirm where IDs are generated.
        checkpoints.sort_unstable();
        Ok(checkpoints)
    }
}
/// Manages the creation and restoration of session checkpoints.
///
/// The manager remembers the content of files at the time they were first
/// snapshotted, turns differences against the current on-disk content into
/// `FileDiff`s when a checkpoint is saved, and can write the recorded "before"
/// content back to disk on rewind.
pub struct CheckpointManager {
/// Directory in which checkpoint JSON files are written and looked up.
checkpoint_dir: PathBuf,
/// Content of each tracked file as captured by its first `snapshot_file`
/// call since the snapshots were last cleared, keyed by the path that was
/// passed in. A file that does not exist at that time is recorded as an
/// empty string.
file_snapshots: HashMap<PathBuf, String>,
}
impl CheckpointManager {
/// Creates a new `CheckpointManager` pointing to the specified directory.
pub fn new(checkpoint_dir: PathBuf) -> Self {
Self {
checkpoint_dir,
file_snapshots: HashMap::new(),
}
}
/// Snapshots a file's current content before modification to track changes.
pub fn snapshot_file(&mut self, path: &Path) -> Result<()> {
if !self.file_snapshots.contains_key(path) {
let content = fs::read_to_string(path).unwrap_or_default();
self.file_snapshots.insert(path.to_path_buf(), content);
}
Ok(())
}
/// Creates a `FileDiff` if the file has been modified since it was snapshotted.
pub fn create_diff(&self, path: &Path) -> Result<Option<FileDiff>> {
if let Some(before) = self.file_snapshots.get(path) {
let after = fs::read_to_string(path).unwrap_or_default();
if before != &after {
Ok(Some(FileDiff::new(
path.to_path_buf(),
before.clone(),
after,
)))
} else {
Ok(None)
}
} else {
Ok(None)
}
}
/// Returns all file modifications tracked since the last checkpoint.
pub fn get_all_diffs(&self) -> Result<Vec<FileDiff>> {
let mut diffs = Vec::new();
for (path, before) in &self.file_snapshots {
let after = fs::read_to_string(path).unwrap_or_default();
if before != &after {
diffs.push(FileDiff::new(path.clone(), before.clone(), after));
}
}
Ok(diffs)
}
/// Clears all internal file snapshots.
pub fn clear_snapshots(&mut self) {
self.file_snapshots.clear();
}
/// Saves the current session state as a new checkpoint.
pub fn save_checkpoint(
&mut self,
id: String,
stats: SessionStats,
history: &SessionHistory,
) -> Result<Checkpoint> {
let file_diffs = self.get_all_diffs()?;
let checkpoint = Checkpoint::new(id, stats, history, file_diffs);
checkpoint.save(&self.checkpoint_dir)?;
self.clear_snapshots();
Ok(checkpoint)
}
/// Loads a checkpoint by ID.
pub fn load_checkpoint(&self, id: &str) -> Result<Checkpoint> {
Checkpoint::load(&self.checkpoint_dir, id)
}
/// Lists all available checkpoints.
pub fn list_checkpoints(&self) -> Result<Vec<String>> {
Checkpoint::list(&self.checkpoint_dir)
}
/// Rewinds the local filesystem to the state captured in the specified checkpoint.
///
/// Returns a list of paths that were restored.
pub fn rewind_to(&self, checkpoint_id: &str) -> Result<Vec<PathBuf>> {
let checkpoint = self.load_checkpoint(checkpoint_id)?;
let mut restored_files = Vec::new();
// Restore files from diffs (revert to 'before' state)
for diff in &checkpoint.file_diffs {
fs::write(&diff.path, &diff.before)?;
restored_files.push(diff.path.clone());
}
Ok(restored_files)
}
}