338 lines
10 KiB
Rust
338 lines
10 KiB
Rust
//! Session state and history management.
|
|
//!
|
|
//! This module provides tools for tracking conversation history, capturing
|
|
//! usage statistics, and managing session checkpoints for persistence and rewind.
|
|
|
|
use color_eyre::eyre::{Result, eyre};
|
|
use serde::{Deserialize, Serialize};
|
|
use std::collections::HashMap;
|
|
use std::fs;
|
|
use std::path::{Path, PathBuf};
|
|
use std::time::{Duration, SystemTime};
|
|
|
|
/// Statistics for a single chat session.
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct SessionStats {
|
|
/// The time when the session started.
|
|
pub start_time: SystemTime,
|
|
/// Total number of messages exchanged.
|
|
pub total_messages: usize,
|
|
/// Total number of tools executed by the agent.
|
|
pub total_tool_calls: usize,
|
|
/// Total wall-clock time spent in the session.
|
|
pub total_duration: Duration,
|
|
/// Rough estimate of the total tokens used.
|
|
pub estimated_tokens: usize,
|
|
}
|
|
|
|
impl SessionStats {
|
|
/// Creates a new `SessionStats` instance with zeroed values.
|
|
pub fn new() -> Self {
|
|
Self {
|
|
start_time: SystemTime::now(),
|
|
total_messages: 0,
|
|
total_tool_calls: 0,
|
|
total_duration: Duration::ZERO,
|
|
estimated_tokens: 0,
|
|
}
|
|
}
|
|
|
|
/// Records a new message in the statistics.
|
|
pub fn record_message(&mut self, tokens: usize, duration: Duration) {
|
|
self.total_messages += 1;
|
|
self.estimated_tokens += tokens;
|
|
self.total_duration += duration;
|
|
}
|
|
|
|
/// Increments the tool call counter.
|
|
pub fn record_tool_call(&mut self) {
|
|
self.total_tool_calls += 1;
|
|
}
|
|
|
|
/// Formats a duration into a human-readable string.
|
|
pub fn format_duration(d: Duration) -> String {
|
|
let secs = d.as_secs();
|
|
if secs < 60 {
|
|
format!("{}s", secs)
|
|
} else if secs < 3600 {
|
|
format!("{}m {}s", secs / 60, secs % 60)
|
|
} else {
|
|
format!("{}h {}m", secs / 3600, (secs % 3600) / 60)
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Default for SessionStats {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
/// In-memory history of the current session.
|
|
#[derive(Debug, Clone)]
|
|
pub struct SessionHistory {
|
|
/// List of prompts provided by the user.
|
|
pub user_prompts: Vec<String>,
|
|
/// List of responses generated by the assistant.
|
|
pub assistant_responses: Vec<String>,
|
|
/// Chronological log of all tool calls made.
|
|
pub tool_calls: Vec<ToolCallRecord>,
|
|
}
|
|
|
|
/// Record of a single tool execution.
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct ToolCallRecord {
|
|
/// Name of the tool that was called.
|
|
pub tool_name: String,
|
|
/// JSON-encoded arguments provided to the tool.
|
|
pub arguments: String,
|
|
/// Output produced by the tool.
|
|
pub result: String,
|
|
/// Whether the tool execution was successful.
|
|
pub success: bool,
|
|
}
|
|
|
|
impl SessionHistory {
|
|
/// Creates a new, empty `SessionHistory`.
|
|
pub fn new() -> Self {
|
|
Self {
|
|
user_prompts: Vec::new(),
|
|
assistant_responses: Vec::new(),
|
|
tool_calls: Vec::new(),
|
|
}
|
|
}
|
|
|
|
/// Appends a user message to history.
|
|
pub fn add_user_message(&mut self, message: String) {
|
|
self.user_prompts.push(message);
|
|
}
|
|
|
|
/// Appends an assistant response to history.
|
|
pub fn add_assistant_message(&mut self, message: String) {
|
|
self.assistant_responses.push(message);
|
|
}
|
|
|
|
/// Appends a tool call record to history.
|
|
pub fn add_tool_call(&mut self, record: ToolCallRecord) {
|
|
self.tool_calls.push(record);
|
|
}
|
|
|
|
/// Clears all stored history.
|
|
pub fn clear(&mut self) {
|
|
self.user_prompts.clear();
|
|
self.assistant_responses.clear();
|
|
self.tool_calls.clear();
|
|
}
|
|
}
|
|
|
|
impl Default for SessionHistory {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
/// Represents a file modification with before/after content.
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct FileDiff {
|
|
/// Absolute path to the file.
|
|
pub path: PathBuf,
|
|
/// Content of the file before modification.
|
|
pub before: String,
|
|
/// Content of the file after modification.
|
|
pub after: String,
|
|
/// When the modification occurred.
|
|
pub timestamp: SystemTime,
|
|
}
|
|
|
|
impl FileDiff {
|
|
/// Creates a new `FileDiff`.
|
|
pub fn new(path: PathBuf, before: String, after: String) -> Self {
|
|
Self {
|
|
path,
|
|
before,
|
|
after,
|
|
timestamp: SystemTime::now(),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// A checkpoint captures the full state of a session at a point in time.
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
pub struct Checkpoint {
|
|
/// Unique identifier for the checkpoint.
|
|
pub id: String,
|
|
/// When the checkpoint was created.
|
|
pub timestamp: SystemTime,
|
|
/// Session statistics at the time of checkpoint.
|
|
pub stats: SessionStats,
|
|
/// History of user prompts.
|
|
pub user_prompts: Vec<String>,
|
|
/// History of assistant responses.
|
|
pub assistant_responses: Vec<String>,
|
|
/// History of tool calls.
|
|
pub tool_calls: Vec<ToolCallRecord>,
|
|
/// List of file modifications made during the session.
|
|
pub file_diffs: Vec<FileDiff>,
|
|
}
|
|
|
|
impl Checkpoint {
|
|
/// Creates a new checkpoint from the current session state.
|
|
pub fn new(
|
|
id: String,
|
|
stats: SessionStats,
|
|
history: &SessionHistory,
|
|
file_diffs: Vec<FileDiff>,
|
|
) -> Self {
|
|
Self {
|
|
id,
|
|
timestamp: SystemTime::now(),
|
|
stats,
|
|
user_prompts: history.user_prompts.clone(),
|
|
assistant_responses: history.assistant_responses.clone(),
|
|
tool_calls: history.tool_calls.clone(),
|
|
file_diffs,
|
|
}
|
|
}
|
|
|
|
/// Saves the checkpoint to a JSON file on disk.
|
|
pub fn save(&self, checkpoint_dir: &Path) -> Result<()> {
|
|
fs::create_dir_all(checkpoint_dir)?;
|
|
let path = checkpoint_dir.join(format!("{}.json", self.id));
|
|
let content = serde_json::to_string_pretty(self)?;
|
|
fs::write(path, content)?;
|
|
Ok(())
|
|
}
|
|
|
|
/// Loads a checkpoint from disk by ID.
|
|
pub fn load(checkpoint_dir: &Path, id: &str) -> Result<Self> {
|
|
let path = checkpoint_dir.join(format!("{}.json", id));
|
|
let content = fs::read_to_string(&path)
|
|
.map_err(|e| eyre!("Failed to read checkpoint: {}", e))?;
|
|
let checkpoint: Checkpoint = serde_json::from_str(&content)
|
|
.map_err(|e| eyre!("Failed to parse checkpoint: {}", e))?;
|
|
Ok(checkpoint)
|
|
}
|
|
|
|
/// Lists all available checkpoint IDs in the given directory.
|
|
pub fn list(checkpoint_dir: &Path) -> Result<Vec<String>> {
|
|
if !checkpoint_dir.exists() {
|
|
return Ok(Vec::new());
|
|
}
|
|
|
|
let mut checkpoints = Vec::new();
|
|
for entry in fs::read_dir(checkpoint_dir)? {
|
|
let entry = entry?;
|
|
let path = entry.path();
|
|
if path.extension().and_then(|s| s.to_str()) == Some("json")
|
|
&& let Some(stem) = path.file_stem().and_then(|s| s.to_str())
|
|
{
|
|
checkpoints.push(stem.to_string());
|
|
}
|
|
}
|
|
|
|
// Sort by checkpoint ID (which includes timestamp)
|
|
checkpoints.sort();
|
|
Ok(checkpoints)
|
|
}
|
|
}
|
|
|
|
/// Tracks file snapshots between checkpoints, and saves and restores
/// checkpoints under a single directory.
#[derive(Debug)]
pub struct CheckpointManager {
    /// Directory holding the `<id>.json` checkpoint files.
    checkpoint_dir: PathBuf,
    /// Content of each tracked file as first seen since the last checkpoint,
    /// keyed by the path passed to `snapshot_file`.
    file_snapshots: HashMap<PathBuf, String>,
}
|
|
|
|
impl CheckpointManager {
|
|
/// Creates a new `CheckpointManager` pointing to the specified directory.
|
|
pub fn new(checkpoint_dir: PathBuf) -> Self {
|
|
Self {
|
|
checkpoint_dir,
|
|
file_snapshots: HashMap::new(),
|
|
}
|
|
}
|
|
|
|
/// Snapshots a file's current content before modification to track changes.
|
|
pub fn snapshot_file(&mut self, path: &Path) -> Result<()> {
|
|
if !self.file_snapshots.contains_key(path) {
|
|
let content = fs::read_to_string(path).unwrap_or_default();
|
|
self.file_snapshots.insert(path.to_path_buf(), content);
|
|
}
|
|
Ok(())
|
|
}
|
|
|
|
/// Creates a `FileDiff` if the file has been modified since it was snapshotted.
|
|
pub fn create_diff(&self, path: &Path) -> Result<Option<FileDiff>> {
|
|
if let Some(before) = self.file_snapshots.get(path) {
|
|
let after = fs::read_to_string(path).unwrap_or_default();
|
|
if before != &after {
|
|
Ok(Some(FileDiff::new(
|
|
path.to_path_buf(),
|
|
before.clone(),
|
|
after,
|
|
)))
|
|
} else {
|
|
Ok(None)
|
|
}
|
|
} else {
|
|
Ok(None)
|
|
}
|
|
}
|
|
|
|
/// Returns all file modifications tracked since the last checkpoint.
|
|
pub fn get_all_diffs(&self) -> Result<Vec<FileDiff>> {
|
|
let mut diffs = Vec::new();
|
|
for (path, before) in &self.file_snapshots {
|
|
let after = fs::read_to_string(path).unwrap_or_default();
|
|
if before != &after {
|
|
diffs.push(FileDiff::new(path.clone(), before.clone(), after));
|
|
}
|
|
}
|
|
Ok(diffs)
|
|
}
|
|
|
|
/// Clears all internal file snapshots.
|
|
pub fn clear_snapshots(&mut self) {
|
|
self.file_snapshots.clear();
|
|
}
|
|
|
|
/// Saves the current session state as a new checkpoint.
|
|
pub fn save_checkpoint(
|
|
&mut self,
|
|
id: String,
|
|
stats: SessionStats,
|
|
history: &SessionHistory,
|
|
) -> Result<Checkpoint> {
|
|
let file_diffs = self.get_all_diffs()?;
|
|
let checkpoint = Checkpoint::new(id, stats, history, file_diffs);
|
|
checkpoint.save(&self.checkpoint_dir)?;
|
|
self.clear_snapshots();
|
|
Ok(checkpoint)
|
|
}
|
|
|
|
/// Loads a checkpoint by ID.
|
|
pub fn load_checkpoint(&self, id: &str) -> Result<Checkpoint> {
|
|
Checkpoint::load(&self.checkpoint_dir, id)
|
|
}
|
|
|
|
/// Lists all available checkpoints.
|
|
pub fn list_checkpoints(&self) -> Result<Vec<String>> {
|
|
Checkpoint::list(&self.checkpoint_dir)
|
|
}
|
|
|
|
/// Rewinds the local filesystem to the state captured in the specified checkpoint.
|
|
///
|
|
/// Returns a list of paths that were restored.
|
|
pub fn rewind_to(&self, checkpoint_id: &str) -> Result<Vec<PathBuf>> {
|
|
let checkpoint = self.load_checkpoint(checkpoint_id)?;
|
|
let mut restored_files = Vec::new();
|
|
|
|
// Restore files from diffs (revert to 'before' state)
|
|
for diff in &checkpoint.file_diffs {
|
|
fs::write(&diff.path, &diff.before)?;
|
|
restored_files.push(diff.path.clone());
|
|
}
|
|
|
|
Ok(restored_files)
|
|
}
|
|
}
|