Introduce `McpCommand` enum and handlers in `owlen-cli` to manage MCP server registrations, including adding, listing, and removing servers across configuration scopes. Add scoped configuration support (`ScopedMcpServer`, `McpConfigScope`) and OAuth token handling in core config, alongside runtime refresh of MCP servers. Implement toast notifications in the TUI (`render_toasts`, `Toast`, `ToastLevel`) and integrate async handling for session events. Update config loading, validation, and schema versioning to accommodate new MCP scopes and resources. Add `httpmock` as a dev dependency for testing.
508 lines
17 KiB
Rust
508 lines
17 KiB
Rust
use std::time::Duration as StdDuration;
|
|
|
|
use chrono::{DateTime, Duration, Utc};
|
|
use reqwest::Client;
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use crate::{Error, Result, config::McpOAuthConfig};
|
|
|
|
/// Persisted OAuth token set for MCP servers and providers.
|
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
|
|
pub struct OAuthToken {
|
|
/// Bearer access token returned by the authorization server.
|
|
pub access_token: String,
|
|
/// Optional refresh token if the provider issues one.
|
|
#[serde(default)]
|
|
pub refresh_token: Option<String>,
|
|
/// Absolute UTC expiration timestamp for the access token.
|
|
#[serde(default)]
|
|
pub expires_at: Option<DateTime<Utc>>,
|
|
/// Optional space-delimited scope string supplied by the provider.
|
|
#[serde(default)]
|
|
pub scope: Option<String>,
|
|
/// Token type reported by the provider (typically `Bearer`).
|
|
#[serde(default)]
|
|
pub token_type: Option<String>,
|
|
}
|
|
|
|
impl OAuthToken {
|
|
/// Returns `true` if the access token has expired at the provided instant.
|
|
pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
|
|
matches!(self.expires_at, Some(expiry) if now >= expiry)
|
|
}
|
|
|
|
/// Returns `true` if the token will expire within the supplied duration window.
|
|
pub fn will_expire_within(&self, window: Duration, now: DateTime<Utc>) -> bool {
|
|
matches!(self.expires_at, Some(expiry) if expiry - now <= window)
|
|
}
|
|
}
|
|
|
|
/// Active device-authorization session details returned by the authorization server.
#[derive(Debug, Clone)]
pub struct DeviceAuthorization {
    /// Code sent back to the token endpoint (as `device_code`) on every poll.
    pub device_code: String,
    /// Code the user enters on the verification page.
    pub user_code: String,
    /// Verification page the user opens to approve the request.
    pub verification_uri: String,
    /// Verification URL with the user code already embedded, if the server supplies one.
    pub verification_uri_complete: Option<String>,
    /// Absolute UTC instant after which this session is no longer valid.
    pub expires_at: DateTime<Utc>,
    /// Delay between token-endpoint polls; also the base for `slow_down` backoff.
    pub interval: StdDuration,
    /// Optional human-readable instructions supplied by the server.
    pub message: Option<String>,
}
|
|
|
|
impl DeviceAuthorization {
|
|
pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
|
|
now >= self.expires_at
|
|
}
|
|
}
|
|
|
|
/// Result of polling the token endpoint during a device-authorization flow.
#[derive(Debug, Clone)]
pub enum DevicePollState {
    /// The user has not finished authorizing yet; poll again after `retry_in`.
    Pending { retry_in: StdDuration },
    /// The token endpoint issued a token; polling is finished.
    Complete(OAuthToken),
}
|
|
|
|
/// Client for the OAuth endpoints of an MCP server: device authorization,
/// device-token polling, and refresh-token exchange.
pub struct OAuthClient {
    /// HTTP client used for every OAuth request.
    http: Client,
    /// Endpoint URLs, client credentials, and scopes sent with each request.
    config: McpOAuthConfig,
}
|
|
|
|
impl OAuthClient {
|
|
pub fn new(config: McpOAuthConfig) -> Result<Self> {
|
|
let http = Client::builder()
|
|
.user_agent("OwlenOAuth/1.0")
|
|
.build()
|
|
.map_err(|err| Error::Network(format!("Failed to construct HTTP client: {err}")))?;
|
|
Ok(Self { http, config })
|
|
}
|
|
|
|
fn scope_value(&self) -> Option<String> {
|
|
if self.config.scopes.is_empty() {
|
|
None
|
|
} else {
|
|
Some(self.config.scopes.join(" "))
|
|
}
|
|
}
|
|
|
|
fn token_request_base(&self) -> Vec<(String, String)> {
|
|
let mut params = vec![("client_id".to_string(), self.config.client_id.clone())];
|
|
if let Some(secret) = &self.config.client_secret {
|
|
params.push(("client_secret".to_string(), secret.clone()));
|
|
}
|
|
params
|
|
}
|
|
|
|
pub async fn start_device_authorization(&self) -> Result<DeviceAuthorization> {
|
|
let device_url = self
|
|
.config
|
|
.device_authorization_url
|
|
.as_ref()
|
|
.ok_or_else(|| {
|
|
Error::Config("Device authorization endpoint is not configured.".to_string())
|
|
})?;
|
|
|
|
let mut params = self.token_request_base();
|
|
if let Some(scope) = self.scope_value() {
|
|
params.push(("scope".to_string(), scope));
|
|
}
|
|
|
|
let response = self
|
|
.http
|
|
.post(device_url)
|
|
.form(¶ms)
|
|
.send()
|
|
.await
|
|
.map_err(|err| map_http_error("start device authorization", err))?;
|
|
|
|
let status = response.status();
|
|
let payload = response
|
|
.json::<DeviceAuthorizationResponse>()
|
|
.await
|
|
.map_err(|err| {
|
|
Error::Auth(format!(
|
|
"Failed to parse device authorization response (status {status}): {err}"
|
|
))
|
|
})?;
|
|
|
|
let expires_at =
|
|
Utc::now() + Duration::seconds(payload.expires_in.min(i64::MAX as u64) as i64);
|
|
let interval = StdDuration::from_secs(payload.interval.unwrap_or(5).max(1));
|
|
|
|
Ok(DeviceAuthorization {
|
|
device_code: payload.device_code,
|
|
user_code: payload.user_code,
|
|
verification_uri: payload.verification_uri,
|
|
verification_uri_complete: payload.verification_uri_complete,
|
|
expires_at,
|
|
interval,
|
|
message: payload.message,
|
|
})
|
|
}
|
|
|
|
pub async fn poll_device_token(&self, auth: &DeviceAuthorization) -> Result<DevicePollState> {
|
|
let mut params = self.token_request_base();
|
|
params.push(("grant_type".to_string(), DEVICE_CODE_GRANT.to_string()));
|
|
params.push(("device_code".to_string(), auth.device_code.clone()));
|
|
if let Some(scope) = self.scope_value() {
|
|
params.push(("scope".to_string(), scope));
|
|
}
|
|
|
|
let response = self
|
|
.http
|
|
.post(&self.config.token_url)
|
|
.form(¶ms)
|
|
.send()
|
|
.await
|
|
.map_err(|err| map_http_error("poll device token", err))?;
|
|
|
|
let status = response.status();
|
|
let text = response
|
|
.text()
|
|
.await
|
|
.map_err(|err| map_http_error("read token response", err))?;
|
|
|
|
if status.is_success() {
|
|
let payload: TokenResponse = serde_json::from_str(&text).map_err(|err| {
|
|
Error::Auth(format!(
|
|
"Failed to parse OAuth token response: {err}; body: {text}"
|
|
))
|
|
})?;
|
|
return Ok(DevicePollState::Complete(oauth_token_from_response(
|
|
payload,
|
|
)));
|
|
}
|
|
|
|
let error = serde_json::from_str::<OAuthErrorResponse>(&text).unwrap_or_else(|_| {
|
|
OAuthErrorResponse {
|
|
error: "unknown_error".to_string(),
|
|
error_description: Some(text.clone()),
|
|
}
|
|
});
|
|
|
|
match error.error.as_str() {
|
|
"authorization_pending" => Ok(DevicePollState::Pending {
|
|
retry_in: auth.interval,
|
|
}),
|
|
"slow_down" => Ok(DevicePollState::Pending {
|
|
retry_in: auth.interval.saturating_add(StdDuration::from_secs(5)),
|
|
}),
|
|
"access_denied" => {
|
|
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
|
|
"User declined authorization".to_string()
|
|
})))
|
|
}
|
|
"expired_token" | "expired_device_code" => {
|
|
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
|
|
"Device authorization expired".to_string()
|
|
})))
|
|
}
|
|
other => Err(Error::Auth(
|
|
error
|
|
.error_description
|
|
.unwrap_or_else(|| format!("OAuth error: {other}")),
|
|
)),
|
|
}
|
|
}
|
|
|
|
pub async fn refresh_token(&self, refresh_token: &str) -> Result<OAuthToken> {
|
|
let mut params = self.token_request_base();
|
|
params.push(("grant_type".to_string(), "refresh_token".to_string()));
|
|
params.push(("refresh_token".to_string(), refresh_token.to_string()));
|
|
if let Some(scope) = self.scope_value() {
|
|
params.push(("scope".to_string(), scope));
|
|
}
|
|
|
|
let response = self
|
|
.http
|
|
.post(&self.config.token_url)
|
|
.form(¶ms)
|
|
.send()
|
|
.await
|
|
.map_err(|err| map_http_error("refresh OAuth token", err))?;
|
|
|
|
let status = response.status();
|
|
let text = response
|
|
.text()
|
|
.await
|
|
.map_err(|err| map_http_error("read refresh response", err))?;
|
|
|
|
if status.is_success() {
|
|
let payload: TokenResponse = serde_json::from_str(&text).map_err(|err| {
|
|
Error::Auth(format!(
|
|
"Failed to parse OAuth refresh response: {err}; body: {text}"
|
|
))
|
|
})?;
|
|
Ok(oauth_token_from_response(payload))
|
|
} else {
|
|
let error = serde_json::from_str::<OAuthErrorResponse>(&text).unwrap_or_else(|_| {
|
|
OAuthErrorResponse {
|
|
error: "unknown_error".to_string(),
|
|
error_description: Some(text.clone()),
|
|
}
|
|
});
|
|
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
|
|
format!("OAuth token refresh failed: {}", error.error)
|
|
})))
|
|
}
|
|
}
|
|
}
|
|
|
|
/// `grant_type` value sent when polling the token endpoint in the device flow (RFC 8628).
const DEVICE_CODE_GRANT: &str = "urn:ietf:params:oauth:grant-type:device_code";
|
|
|
|
/// JSON body returned by the device authorization endpoint.
#[derive(Debug, Deserialize)]
struct DeviceAuthorizationResponse {
    device_code: String,
    user_code: String,
    verification_uri: String,
    #[serde(default)]
    verification_uri_complete: Option<String>,
    /// Lifetime of the device code, in seconds.
    expires_in: u64,
    /// Suggested polling interval in seconds; the client falls back to 5 when absent.
    #[serde(default)]
    interval: Option<u64>,
    #[serde(default)]
    message: Option<String>,
}
|
|
|
|
/// JSON body of a successful token-endpoint response.
// NOTE(review): the derived `Debug` prints `access_token`/`refresh_token` in clear
// text; avoid `{:?}`-formatting this type.
#[derive(Debug, Deserialize)]
struct TokenResponse {
    access_token: String,
    #[serde(default)]
    refresh_token: Option<String>,
    /// Access-token lifetime in seconds, if the server reports one.
    #[serde(default)]
    expires_in: Option<u64>,
    #[serde(default)]
    scope: Option<String>,
    #[serde(default)]
    token_type: Option<String>,
}
|
|
|
|
/// JSON error body returned by the OAuth endpoints on non-success responses.
#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
    /// Machine-readable error code, e.g. `authorization_pending` or `slow_down`.
    error: String,
    /// Optional human-readable explanation of the error.
    #[serde(default)]
    error_description: Option<String>,
}
|
|
|
|
fn oauth_token_from_response(payload: TokenResponse) -> OAuthToken {
|
|
let expires_at = payload
|
|
.expires_in
|
|
.map(|seconds| seconds.min(i64::MAX as u64) as i64)
|
|
.map(|seconds| Utc::now() + Duration::seconds(seconds));
|
|
|
|
OAuthToken {
|
|
access_token: payload.access_token,
|
|
refresh_token: payload.refresh_token,
|
|
expires_at,
|
|
scope: payload.scope,
|
|
token_type: payload.token_type,
|
|
}
|
|
}
|
|
|
|
fn map_http_error(action: &str, err: reqwest::Error) -> Error {
|
|
if err.is_timeout() {
|
|
Error::Timeout(format!("OAuth {action} request timed out: {err}"))
|
|
} else if err.is_connect() {
|
|
Error::Network(format!("OAuth {action} connection error: {err}"))
|
|
} else {
|
|
Error::Network(format!("OAuth {action} request failed: {err}"))
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    // Drives `OAuthClient` against a local `httpmock` server; each test registers
    // the expected request and asserts the mock was hit.
    use super::*;
    use httpmock::prelude::*;
    use serde_json::json;

    /// OAuth config whose authorize, token, and device endpoints all point at
    /// `server`, with no client secret and the `repo` + `user` scopes.
    fn config_for(server: &MockServer) -> McpOAuthConfig {
        McpOAuthConfig {
            client_id: "test-client".to_string(),
            client_secret: None,
            authorize_url: server.url("/authorize"),
            token_url: server.url("/token"),
            device_authorization_url: Some(server.url("/device")),
            redirect_url: None,
            scopes: vec!["repo".to_string(), "user".to_string()],
            token_env: None,
            header: None,
            header_prefix: None,
        }
    }

    /// Device-authorization fixture with a 5-second polling interval that
    /// expires 10 minutes from now.
    fn sample_device_authorization() -> DeviceAuthorization {
        DeviceAuthorization {
            device_code: "device-123".to_string(),
            user_code: "ABCD-EFGH".to_string(),
            verification_uri: "https://example.test/activate".to_string(),
            verification_uri_complete: Some(
                "https://example.test/activate?user_code=ABCD-EFGH".to_string(),
            ),
            expires_at: Utc::now() + Duration::minutes(10),
            interval: StdDuration::from_secs(5),
            message: Some("Open the verification URL and enter the code.".to_string()),
        }
    }

    /// The device endpoint's JSON is mapped into `DeviceAuthorization`, using the
    /// server-supplied 7-second interval instead of the default.
    #[tokio::test]
    async fn start_device_authorization_returns_payload() {
        let server = MockServer::start_async().await;
        let device_mock = server
            .mock_async(|when, then| {
                when.method(POST).path("/device");
                then.status(200)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "device_code": "device-123",
                        "user_code": "ABCD-EFGH",
                        "verification_uri": "https://example.test/activate",
                        "verification_uri_complete": "https://example.test/activate?user_code=ABCD-EFGH",
                        "expires_in": 600,
                        "interval": 7,
                        "message": "Open the verification URL and enter the code."
                    }));
            })
            .await;

        let client = OAuthClient::new(config_for(&server)).expect("client");
        let auth = client
            .start_device_authorization()
            .await
            .expect("device authorization payload");

        assert_eq!(auth.user_code, "ABCD-EFGH");
        assert_eq!(auth.interval, StdDuration::from_secs(7));
        assert!(auth.expires_at > Utc::now());
        device_mock.assert_async().await;
    }

    /// `authorization_pending` maps to `Pending` with the fixture's 5-second
    /// interval; the mock also checks the form-encoded grant type and device code.
    #[tokio::test]
    async fn poll_device_token_reports_pending() {
        let server = MockServer::start_async().await;
        let pending = server
            .mock_async(|when, then| {
                when.method(POST)
                    .path("/token")
                    .body_contains(
                        "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code",
                    )
                    .body_contains("device_code=device-123");
                then.status(400)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "error": "authorization_pending"
                    }));
            })
            .await;

        let config = config_for(&server);
        let client = OAuthClient::new(config).expect("client");
        let auth = sample_device_authorization();

        let result = client.poll_device_token(&auth).await.expect("poll result");
        match result {
            DevicePollState::Pending { retry_in } => {
                assert_eq!(retry_in, StdDuration::from_secs(5));
            }
            other => panic!("expected pending state, got {other:?}"),
        }

        pending.assert_async().await;
    }

    /// `slow_down` adds 5 seconds to the fixture's 5-second interval.
    #[tokio::test]
    async fn poll_device_token_applies_slow_down_backoff() {
        let server = MockServer::start_async().await;
        let slow = server
            .mock_async(|when, then| {
                when.method(POST).path("/token");
                then.status(400)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "error": "slow_down"
                    }));
            })
            .await;

        let config = config_for(&server);
        let client = OAuthClient::new(config).expect("client");
        let auth = sample_device_authorization();

        let result = client.poll_device_token(&auth).await.expect("poll result");
        match result {
            DevicePollState::Pending { retry_in } => {
                assert_eq!(retry_in, StdDuration::from_secs(10));
            }
            other => panic!("expected pending state, got {other:?}"),
        }

        slow.assert_async().await;
    }

    /// A 200 token response completes the flow with the issued tokens and an
    /// expiry derived from `expires_in`.
    #[tokio::test]
    async fn poll_device_token_returns_token_when_authorized() {
        let server = MockServer::start_async().await;
        let token = server
            .mock_async(|when, then| {
                when.method(POST).path("/token");
                then.status(200)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "access_token": "token-abc",
                        "refresh_token": "refresh-xyz",
                        "expires_in": 3600,
                        "token_type": "Bearer",
                        "scope": "repo user"
                    }));
            })
            .await;

        let config = config_for(&server);
        let client = OAuthClient::new(config).expect("client");
        let auth = sample_device_authorization();

        let result = client.poll_device_token(&auth).await.expect("poll result");
        let token_info = match result {
            DevicePollState::Complete(token) => token,
            other => panic!("expected completion, got {other:?}"),
        };

        assert_eq!(token_info.access_token, "token-abc");
        assert_eq!(token_info.refresh_token.as_deref(), Some("refresh-xyz"));
        assert!(token_info.expires_at.is_some());
        token.assert_async().await;
    }

    /// A refresh request sends `grant_type=refresh_token` with the old refresh
    /// token and returns the rotated token pair.
    #[tokio::test]
    async fn refresh_token_roundtrip() {
        let server = MockServer::start_async().await;
        let refresh = server
            .mock_async(|when, then| {
                when.method(POST)
                    .path("/token")
                    .body_contains("grant_type=refresh_token")
                    .body_contains("refresh_token=old-refresh");
                then.status(200)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "access_token": "token-new",
                        "refresh_token": "refresh-new",
                        "expires_in": 1200,
                        "token_type": "Bearer"
                    }));
            })
            .await;

        let config = config_for(&server);
        let client = OAuthClient::new(config).expect("client");
        let token = client
            .refresh_token("old-refresh")
            .await
            .expect("refresh response");

        assert_eq!(token.access_token, "token-new");
        assert_eq!(token.refresh_token.as_deref(), Some("refresh-new"));
        assert!(token.expires_at.is_some());
        refresh.assert_async().await;
    }
}
|