//! Slash command parsing for chat input.
//!
//! Provides lightweight handling for inline commands such as `/summarize`
//! and `/testplan`. The parser returns owned data so callers can prepare
//! requests immediately without additional lifetime juggling.

use std::collections::HashMap;
|
|
use std::fmt;
|
|
use std::str::FromStr;
|
|
use std::sync::{OnceLock, RwLock};
|
|
|
|
/// Supported slash commands.
///
/// Every variant owns its payload, so a parsed command does not borrow from
/// the input string it was parsed from. Commands compare by value, which lets
/// callers and tests use `==` / `assert_eq!` instead of hand-written matches.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SlashCommand {
    /// `/summarize [count]`; `count` is `None` when no numeric argument was parsed.
    Summarize { count: Option<usize> },
    /// `/explain <text>`; `snippet` holds the text that followed the command name.
    Explain { snippet: String },
    /// `/refactor <path>`; `path` holds the argument that followed the command name.
    Refactor { path: String },
    /// `/testplan`.
    TestPlan,
    /// `/compact`.
    Compact,
    /// A dynamically registered MCP tool, carrying the server and tool names
    /// it was registered with.
    McpTool { server: String, tool: String },
}
|
|
|
|
/// Errors emitted when parsing invalid slash input.
///
/// The error is cheap to clone and compares by value so callers can match on
/// or assert against a specific failure without formatting it first.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SlashError {
    /// The command name (carried without the leading `/`) is not recognised.
    UnknownCommand(String),
    /// Any other parse failure; the payload is the complete message to display.
    Message(String),
}
|
|
|
|
impl fmt::Display for SlashError {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
match self {
|
|
SlashError::UnknownCommand(cmd) => write!(f, "unknown slash command: {cmd}"),
|
|
SlashError::Message(msg) => f.write_str(msg),
|
|
}
|
|
}
|
|
}
|
|
|
|
// Marker impl with an empty body: every `Error` method keeps its default
// implementation, so this error reports no underlying source.
impl std::error::Error for SlashError {}
|
|
|
|
/// A slash command backed by an MCP tool.
///
/// Values compare by field so registry contents can be checked with `==`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct McpSlashCommand {
    /// Server name exactly as supplied by the caller.
    pub server: String,
    /// Tool name exactly as supplied by the caller.
    pub tool: String,
    /// Command name (without the leading `/`) under which the command is
    /// registered and looked up. `McpSlashCommand::new` derives it as
    /// `mcp__<server>__<tool>` from the canonicalized names.
    pub keyword: String,
    /// Optional description, stored as given.
    pub description: Option<String>,
}
|
|
|
|
impl McpSlashCommand {
|
|
pub fn new(
|
|
server: impl Into<String>,
|
|
tool: impl Into<String>,
|
|
description: Option<String>,
|
|
) -> Self {
|
|
let server = server.into();
|
|
let tool = tool.into();
|
|
let keyword = format!(
|
|
"mcp__{}__{}",
|
|
canonicalize_component(&server),
|
|
canonicalize_component(&tool)
|
|
);
|
|
Self {
|
|
server,
|
|
tool,
|
|
keyword,
|
|
description,
|
|
}
|
|
}
|
|
}
|
|
|
|
static MCP_COMMANDS: OnceLock<RwLock<HashMap<String, McpSlashCommand>>> = OnceLock::new();
|
|
|
|
fn dynamic_registry() -> &'static RwLock<HashMap<String, McpSlashCommand>> {
|
|
MCP_COMMANDS.get_or_init(|| RwLock::new(HashMap::new()))
|
|
}
|
|
|
|
pub fn set_mcp_commands(commands: impl IntoIterator<Item = McpSlashCommand>) {
|
|
let registry = dynamic_registry();
|
|
let mut guard = registry.write().expect("MCP command registry poisoned");
|
|
guard.clear();
|
|
for command in commands {
|
|
guard.insert(command.keyword.clone(), command);
|
|
}
|
|
}
|
|
|
|
fn find_mcp_command(keyword: &str) -> Option<McpSlashCommand> {
|
|
let registry = dynamic_registry();
|
|
let guard = registry.read().expect("MCP command registry poisoned");
|
|
guard.get(keyword).cloned()
|
|
}
|
|
|
|
/// Normalises one keyword component.
///
/// ASCII letters and digits are kept (lower-cased); every other character,
/// including `_` itself and non-ASCII text, becomes `_`, with consecutive
/// underscores collapsed into one. An input that yields nothing produces `"_"`.
fn canonicalize_component(input: &str) -> String {
    let mut canonical = String::with_capacity(input.len());
    for ch in input.chars() {
        if ch.is_ascii_alphanumeric() {
            canonical.push(ch.to_ascii_lowercase());
        } else if !canonical.ends_with('_') {
            // Only this branch ever appends '_', so checking the tail is
            // enough to collapse runs of separators.
            canonical.push('_');
        }
    }
    if canonical.is_empty() {
        canonical.push('_');
    }
    canonical
}
|
|
|
|
/// Attempt to parse a slash command from the provided input.
|
|
pub fn parse(input: &str) -> Result<Option<SlashCommand>, SlashError> {
|
|
let trimmed = input.trim();
|
|
if !trimmed.starts_with('/') {
|
|
return Ok(None);
|
|
}
|
|
|
|
let body = trimmed.trim_start_matches('/');
|
|
if body.is_empty() {
|
|
return Err(SlashError::Message("missing command name after '/'".into()));
|
|
}
|
|
|
|
let mut parts = body.split_whitespace();
|
|
let command = parts.next().unwrap();
|
|
let remainder = parts.collect::<Vec<_>>();
|
|
|
|
if let Some(dynamic) = find_mcp_command(command) {
|
|
if !remainder.is_empty() {
|
|
return Err(SlashError::Message(format!(
|
|
"/{} does not accept arguments",
|
|
dynamic.keyword
|
|
)));
|
|
}
|
|
return Ok(Some(SlashCommand::McpTool {
|
|
server: dynamic.server,
|
|
tool: dynamic.tool,
|
|
}));
|
|
}
|
|
|
|
let cmd = match command {
|
|
"summarize" => {
|
|
let count = remainder
|
|
.first()
|
|
.and_then(|value| usize::from_str(value).ok());
|
|
SlashCommand::Summarize { count }
|
|
}
|
|
"explain" => {
|
|
if remainder.is_empty() {
|
|
return Err(SlashError::Message(
|
|
"usage: /explain <code snippet or description>".into(),
|
|
));
|
|
}
|
|
SlashCommand::Explain {
|
|
snippet: remainder.join(" "),
|
|
}
|
|
}
|
|
"refactor" => {
|
|
if remainder.is_empty() {
|
|
return Err(SlashError::Message(
|
|
"usage: /refactor <relative/path/to/file>".into(),
|
|
));
|
|
}
|
|
SlashCommand::Refactor {
|
|
path: remainder.join(" "),
|
|
}
|
|
}
|
|
"testplan" => SlashCommand::TestPlan,
|
|
"compact" => SlashCommand::Compact,
|
|
other => return Err(SlashError::UnknownCommand(other.to_string())),
|
|
};
|
|
|
|
Ok(Some(cmd))
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialises the tests that write to the process-wide MCP registry, since
    /// the test harness runs tests on several threads by default.
    fn registry_guard() -> std::sync::MutexGuard<'static, ()> {
        static GUARD: std::sync::OnceLock<std::sync::Mutex<()>> = std::sync::OnceLock::new();
        GUARD
            .get_or_init(|| std::sync::Mutex::new(()))
            .lock()
            // The mutex guards no data, so a lock poisoned by one failing test
            // is safe to reuse; recovering keeps that failure from cascading
            // into every other registry test.
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    #[test]
    fn ignores_non_command_input() {
        let result = parse("hello world").unwrap();
        assert!(result.is_none());
    }

    #[test]
    fn parses_summarize_with_count() {
        let command = parse("/summarize 10").unwrap().expect("expected command");
        match command {
            SlashCommand::Summarize { count } => assert_eq!(count, Some(10)),
            other => panic!("unexpected command: {:?}", other),
        }
    }

    #[test]
    fn returns_error_for_unknown_command() {
        let err = parse("/unknown").unwrap_err();
        assert_eq!(err.to_string(), "unknown slash command: unknown");
    }

    #[test]
    fn parses_registered_mcp_command() {
        let _registry = registry_guard();
        set_mcp_commands(vec![McpSlashCommand::new("github", "list_prs", None)]);

        let command = parse("/mcp__github__list_prs")
            .unwrap()
            .expect("expected command");
        match command {
            SlashCommand::McpTool { server, tool } => {
                assert_eq!(server, "github");
                assert_eq!(tool, "list_prs");
            }
            other => panic!("unexpected command variant: {:?}", other),
        }
    }

    #[test]
    fn rejects_mcp_command_with_arguments() {
        let _registry = registry_guard();
        set_mcp_commands(vec![McpSlashCommand::new("github", "list_prs", None)]);

        let err = parse("/mcp__github__list_prs extra").unwrap_err();
        assert_eq!(
            err.to_string(),
            "/mcp__github__list_prs does not accept arguments"
        );
    }

    #[test]
    fn canonicalizes_mcp_command_components() {
        // Deliberately leaves the global registry alone: building a command
        // does not consult it, and clearing it here without `registry_guard`
        // could wipe the entries another test has just registered.
        let entry = McpSlashCommand::new("GitHub", "list/prs", None);
        assert_eq!(entry.keyword, "mcp__github__list_prs");
    }
}
|