Files
owlen/crates/owlen-core/src/oauth.rs
vikingowl 690f5c7056 feat(cli): add MCP management subcommand with add/list/remove commands
Introduce `McpCommand` enum and handlers in `owlen-cli` to manage MCP server registrations, including adding, listing, and removing servers across configuration scopes. Add scoped configuration support (`ScopedMcpServer`, `McpConfigScope`) and OAuth token handling in core config, alongside runtime refresh of MCP servers. Implement toast notifications in the TUI (`render_toasts`, `Toast`, `ToastLevel`) and integrate async handling for session events. Update config loading, validation, and schema versioning to accommodate new MCP scopes and resources. Add `httpmock` as a dev dependency for testing.
2025-10-13 17:54:14 +02:00

508 lines
17 KiB
Rust

use std::time::Duration as StdDuration;
use chrono::{DateTime, Duration, Utc};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use crate::{Error, Result, config::McpOAuthConfig};
/// Persisted OAuth token set for MCP servers and providers.
///
/// Every field except `access_token` carries `#[serde(default)]`, so a stored token that omits
/// any of the optional fields still deserializes (the missing fields become `None`).
///
/// NOTE(review): the derived `Debug` prints `access_token` and `refresh_token` verbatim; avoid
/// logging values of this type with `{:?}`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub struct OAuthToken {
    /// Bearer access token returned by the authorization server.
    pub access_token: String,
    /// Optional refresh token if the provider issues one.
    #[serde(default)]
    pub refresh_token: Option<String>,
    /// Absolute UTC expiration timestamp for the access token.
    ///
    /// `None` means no expiry is known; `is_expired` and `will_expire_within` then return `false`.
    #[serde(default)]
    pub expires_at: Option<DateTime<Utc>>,
    /// Optional space-delimited scope string supplied by the provider.
    #[serde(default)]
    pub scope: Option<String>,
    /// Token type reported by the provider (typically `Bearer`).
    #[serde(default)]
    pub token_type: Option<String>,
}
impl OAuthToken {
    /// Reports whether the access token has already expired at `now`.
    ///
    /// A token without a known expiry (`expires_at == None`) is never considered expired.
    pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
        if let Some(deadline) = self.expires_at {
            now >= deadline
        } else {
            false
        }
    }

    /// Reports whether the access token expires no later than `window` after `now`.
    ///
    /// For a non-negative `window`, a token that is already expired also satisfies this check.
    /// A token without a known expiry never does.
    pub fn will_expire_within(&self, window: Duration, now: DateTime<Utc>) -> bool {
        match self.expires_at {
            Some(deadline) => {
                let remaining = deadline - now;
                remaining <= window
            }
            None => false,
        }
    }
}
/// Active device-authorization session details returned by the authorization server.
///
/// Produced by `OAuthClient::start_device_authorization`, which converts the server's relative
/// `expires_in` / `interval` values into an absolute deadline and a polling interval.
#[derive(Debug, Clone)]
pub struct DeviceAuthorization {
    /// Code sent back to the token endpoint (as `device_code`) on every poll.
    pub device_code: String,
    /// Code shown to the user to enter at `verification_uri`.
    pub user_code: String,
    /// Verification page URL supplied by the server.
    pub verification_uri: String,
    /// Optional verification URL supplied by the server in addition to `verification_uri`.
    pub verification_uri_complete: Option<String>,
    /// Absolute UTC instant after which this session is considered expired.
    pub expires_at: DateTime<Utc>,
    /// Delay between token-endpoint polls (defaults to 5 s when the server omits it; never
    /// below 1 s).
    pub interval: StdDuration,
    /// Optional display message supplied by the server, if any.
    pub message: Option<String>,
}
impl DeviceAuthorization {
    /// Returns `true` once `now` has reached or passed the session's `expires_at` deadline.
    pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
        let deadline = self.expires_at;
        deadline <= now
    }
}
/// Result of polling the token endpoint during a device-authorization flow.
#[derive(Debug, Clone)]
pub enum DevicePollState {
    /// Authorization has not completed yet (`authorization_pending` or `slow_down`);
    /// poll again after `retry_in`.
    Pending { retry_in: StdDuration },
    /// The server issued a token; the flow is finished.
    Complete(OAuthToken),
}
/// HTTP client for the OAuth device-authorization and refresh-token flows of a single
/// provider, as described by its `McpOAuthConfig`.
pub struct OAuthClient {
    /// `reqwest` client used for every request (built in `OAuthClient::new`).
    http: Client,
    /// Endpoints, client credentials and scopes for this provider.
    config: McpOAuthConfig,
}
impl OAuthClient {
pub fn new(config: McpOAuthConfig) -> Result<Self> {
let http = Client::builder()
.user_agent("OwlenOAuth/1.0")
.build()
.map_err(|err| Error::Network(format!("Failed to construct HTTP client: {err}")))?;
Ok(Self { http, config })
}
fn scope_value(&self) -> Option<String> {
if self.config.scopes.is_empty() {
None
} else {
Some(self.config.scopes.join(" "))
}
}
fn token_request_base(&self) -> Vec<(String, String)> {
let mut params = vec![("client_id".to_string(), self.config.client_id.clone())];
if let Some(secret) = &self.config.client_secret {
params.push(("client_secret".to_string(), secret.clone()));
}
params
}
pub async fn start_device_authorization(&self) -> Result<DeviceAuthorization> {
let device_url = self
.config
.device_authorization_url
.as_ref()
.ok_or_else(|| {
Error::Config("Device authorization endpoint is not configured.".to_string())
})?;
let mut params = self.token_request_base();
if let Some(scope) = self.scope_value() {
params.push(("scope".to_string(), scope));
}
let response = self
.http
.post(device_url)
.form(&params)
.send()
.await
.map_err(|err| map_http_error("start device authorization", err))?;
let status = response.status();
let payload = response
.json::<DeviceAuthorizationResponse>()
.await
.map_err(|err| {
Error::Auth(format!(
"Failed to parse device authorization response (status {status}): {err}"
))
})?;
let expires_at =
Utc::now() + Duration::seconds(payload.expires_in.min(i64::MAX as u64) as i64);
let interval = StdDuration::from_secs(payload.interval.unwrap_or(5).max(1));
Ok(DeviceAuthorization {
device_code: payload.device_code,
user_code: payload.user_code,
verification_uri: payload.verification_uri,
verification_uri_complete: payload.verification_uri_complete,
expires_at,
interval,
message: payload.message,
})
}
pub async fn poll_device_token(&self, auth: &DeviceAuthorization) -> Result<DevicePollState> {
let mut params = self.token_request_base();
params.push(("grant_type".to_string(), DEVICE_CODE_GRANT.to_string()));
params.push(("device_code".to_string(), auth.device_code.clone()));
if let Some(scope) = self.scope_value() {
params.push(("scope".to_string(), scope));
}
let response = self
.http
.post(&self.config.token_url)
.form(&params)
.send()
.await
.map_err(|err| map_http_error("poll device token", err))?;
let status = response.status();
let text = response
.text()
.await
.map_err(|err| map_http_error("read token response", err))?;
if status.is_success() {
let payload: TokenResponse = serde_json::from_str(&text).map_err(|err| {
Error::Auth(format!(
"Failed to parse OAuth token response: {err}; body: {text}"
))
})?;
return Ok(DevicePollState::Complete(oauth_token_from_response(
payload,
)));
}
let error = serde_json::from_str::<OAuthErrorResponse>(&text).unwrap_or_else(|_| {
OAuthErrorResponse {
error: "unknown_error".to_string(),
error_description: Some(text.clone()),
}
});
match error.error.as_str() {
"authorization_pending" => Ok(DevicePollState::Pending {
retry_in: auth.interval,
}),
"slow_down" => Ok(DevicePollState::Pending {
retry_in: auth.interval.saturating_add(StdDuration::from_secs(5)),
}),
"access_denied" => {
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
"User declined authorization".to_string()
})))
}
"expired_token" | "expired_device_code" => {
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
"Device authorization expired".to_string()
})))
}
other => Err(Error::Auth(
error
.error_description
.unwrap_or_else(|| format!("OAuth error: {other}")),
)),
}
}
pub async fn refresh_token(&self, refresh_token: &str) -> Result<OAuthToken> {
let mut params = self.token_request_base();
params.push(("grant_type".to_string(), "refresh_token".to_string()));
params.push(("refresh_token".to_string(), refresh_token.to_string()));
if let Some(scope) = self.scope_value() {
params.push(("scope".to_string(), scope));
}
let response = self
.http
.post(&self.config.token_url)
.form(&params)
.send()
.await
.map_err(|err| map_http_error("refresh OAuth token", err))?;
let status = response.status();
let text = response
.text()
.await
.map_err(|err| map_http_error("read refresh response", err))?;
if status.is_success() {
let payload: TokenResponse = serde_json::from_str(&text).map_err(|err| {
Error::Auth(format!(
"Failed to parse OAuth refresh response: {err}; body: {text}"
))
})?;
Ok(oauth_token_from_response(payload))
} else {
let error = serde_json::from_str::<OAuthErrorResponse>(&text).unwrap_or_else(|_| {
OAuthErrorResponse {
error: "unknown_error".to_string(),
error_description: Some(text.clone()),
}
});
Err(Error::Auth(error.error_description.unwrap_or_else(|| {
format!("OAuth token refresh failed: {}", error.error)
})))
}
}
}
/// `grant_type` value sent when polling the token endpoint in the device flow (RFC 8628 §3.4).
const DEVICE_CODE_GRANT: &str = "urn:ietf:params:oauth:grant-type:device_code";
/// Wire format of the device-authorization endpoint's JSON response.
#[derive(Debug, Deserialize)]
struct DeviceAuthorizationResponse {
    /// Code used when polling the token endpoint.
    device_code: String,
    /// Code the user enters at `verification_uri`.
    user_code: String,
    /// Verification page URL.
    verification_uri: String,
    /// Optional alternative verification URL.
    #[serde(default)]
    verification_uri_complete: Option<String>,
    /// Lifetime of the device code, in seconds from the time of the response.
    expires_in: u64,
    /// Optional polling interval in seconds.
    #[serde(default)]
    interval: Option<u64>,
    /// Optional display message.
    #[serde(default)]
    message: Option<String>,
}
/// Wire format of a successful token-endpoint JSON response.
#[derive(Debug, Deserialize)]
struct TokenResponse {
    /// Issued access token.
    access_token: String,
    /// Optional refresh token.
    #[serde(default)]
    refresh_token: Option<String>,
    /// Optional access-token lifetime in seconds from the time of the response.
    ///
    /// NOTE(review): this must be a JSON number; a provider that sends it as a string would
    /// fail to parse — confirm for the providers in use.
    #[serde(default)]
    expires_in: Option<u64>,
    /// Optional granted scope string.
    #[serde(default)]
    scope: Option<String>,
    /// Optional token type (e.g. `Bearer`).
    #[serde(default)]
    token_type: Option<String>,
}
/// Wire format of an OAuth error response body.
#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
    /// Error code such as `authorization_pending`, `slow_down` or `access_denied`.
    error: String,
    /// Optional human-readable description, preferred over generic messages when present.
    #[serde(default)]
    error_description: Option<String>,
}
/// Converts a successful token-endpoint response into a persisted [`OAuthToken`].
///
/// The relative `expires_in` lifetime is anchored to the current time.
///
/// `expires_in` comes straight from the server. chrono panics when a `Duration` or the
/// resulting timestamp is out of range, so the lifetime is capped at `u32::MAX` seconds
/// (roughly 136 years). The previous `i64::MAX` clamp was not enough to prevent that panic.
fn oauth_token_from_response(payload: TokenResponse) -> OAuthToken {
    const MAX_LIFETIME_SECS: u64 = u32::MAX as u64;

    let expires_at = payload.expires_in.map(|seconds| {
        // The cap keeps the value far below i64::MAX, so the cast cannot wrap.
        let seconds = seconds.min(MAX_LIFETIME_SECS) as i64;
        Utc::now() + Duration::seconds(seconds)
    });
    OAuthToken {
        access_token: payload.access_token,
        refresh_token: payload.refresh_token,
        expires_at,
        scope: payload.scope,
        token_type: payload.token_type,
    }
}
fn map_http_error(action: &str, err: reqwest::Error) -> Error {
if err.is_timeout() {
Error::Timeout(format!("OAuth {action} request timed out: {err}"))
} else if err.is_connect() {
Error::Network(format!("OAuth {action} connection error: {err}"))
} else {
Error::Network(format!("OAuth {action} request failed: {err}"))
}
}
#[cfg(test)]
mod tests {
    // HTTP-level tests for `OAuthClient`, run against a local `httpmock` server.
    use super::*;
    use httpmock::prelude::*;
    use serde_json::json;

    /// Builds an OAuth config whose endpoints all point at `server`, with two scopes so the
    /// space-joined `scope` parameter is exercised.
    fn config_for(server: &MockServer) -> McpOAuthConfig {
        McpOAuthConfig {
            client_id: "test-client".to_string(),
            client_secret: None,
            authorize_url: server.url("/authorize"),
            token_url: server.url("/token"),
            device_authorization_url: Some(server.url("/device")),
            redirect_url: None,
            scopes: vec!["repo".to_string(), "user".to_string()],
            token_env: None,
            header: None,
            header_prefix: None,
        }
    }

    /// A device-authorization session valid for ten minutes with a 5 s polling interval.
    fn sample_device_authorization() -> DeviceAuthorization {
        DeviceAuthorization {
            device_code: "device-123".to_string(),
            user_code: "ABCD-EFGH".to_string(),
            verification_uri: "https://example.test/activate".to_string(),
            verification_uri_complete: Some(
                "https://example.test/activate?user_code=ABCD-EFGH".to_string(),
            ),
            expires_at: Utc::now() + Duration::minutes(10),
            interval: StdDuration::from_secs(5),
            message: Some("Open the verification URL and enter the code.".to_string()),
        }
    }

    /// The device-authorization response is mapped field-for-field, and the server's
    /// `interval` (7 s) overrides the 5 s default.
    #[tokio::test]
    async fn start_device_authorization_returns_payload() {
        let server = MockServer::start_async().await;
        let device_mock = server
            .mock_async(|when, then| {
                when.method(POST).path("/device");
                then.status(200)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "device_code": "device-123",
                        "user_code": "ABCD-EFGH",
                        "verification_uri": "https://example.test/activate",
                        "verification_uri_complete": "https://example.test/activate?user_code=ABCD-EFGH",
                        "expires_in": 600,
                        "interval": 7,
                        "message": "Open the verification URL and enter the code."
                    }));
            })
            .await;
        let client = OAuthClient::new(config_for(&server)).expect("client");
        let auth = client
            .start_device_authorization()
            .await
            .expect("device authorization payload");
        assert_eq!(auth.user_code, "ABCD-EFGH");
        assert_eq!(auth.interval, StdDuration::from_secs(7));
        assert!(auth.expires_at > Utc::now());
        device_mock.assert_async().await;
    }

    /// `authorization_pending` yields `Pending` with the session's unchanged interval, and the
    /// request carries the form-encoded device-code grant type and device code.
    #[tokio::test]
    async fn poll_device_token_reports_pending() {
        let server = MockServer::start_async().await;
        let pending = server
            .mock_async(|when, then| {
                when.method(POST)
                    .path("/token")
                    .body_contains(
                        "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code",
                    )
                    .body_contains("device_code=device-123");
                then.status(400)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "error": "authorization_pending"
                    }));
            })
            .await;
        let config = config_for(&server);
        let client = OAuthClient::new(config).expect("client");
        let auth = sample_device_authorization();
        let result = client.poll_device_token(&auth).await.expect("poll result");
        match result {
            DevicePollState::Pending { retry_in } => {
                assert_eq!(retry_in, StdDuration::from_secs(5));
            }
            other => panic!("expected pending state, got {other:?}"),
        }
        pending.assert_async().await;
    }

    /// `slow_down` adds 5 s to the 5 s base interval.
    #[tokio::test]
    async fn poll_device_token_applies_slow_down_backoff() {
        let server = MockServer::start_async().await;
        let slow = server
            .mock_async(|when, then| {
                when.method(POST).path("/token");
                then.status(400)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "error": "slow_down"
                    }));
            })
            .await;
        let config = config_for(&server);
        let client = OAuthClient::new(config).expect("client");
        let auth = sample_device_authorization();
        let result = client.poll_device_token(&auth).await.expect("poll result");
        match result {
            DevicePollState::Pending { retry_in } => {
                assert_eq!(retry_in, StdDuration::from_secs(10));
            }
            other => panic!("expected pending state, got {other:?}"),
        }
        slow.assert_async().await;
    }

    /// A 200 token response completes the flow with an absolute expiry computed.
    #[tokio::test]
    async fn poll_device_token_returns_token_when_authorized() {
        let server = MockServer::start_async().await;
        let token = server
            .mock_async(|when, then| {
                when.method(POST).path("/token");
                then.status(200)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "access_token": "token-abc",
                        "refresh_token": "refresh-xyz",
                        "expires_in": 3600,
                        "token_type": "Bearer",
                        "scope": "repo user"
                    }));
            })
            .await;
        let config = config_for(&server);
        let client = OAuthClient::new(config).expect("client");
        let auth = sample_device_authorization();
        let result = client.poll_device_token(&auth).await.expect("poll result");
        let token_info = match result {
            DevicePollState::Complete(token) => token,
            other => panic!("expected completion, got {other:?}"),
        };
        assert_eq!(token_info.access_token, "token-abc");
        assert_eq!(token_info.refresh_token.as_deref(), Some("refresh-xyz"));
        assert!(token_info.expires_at.is_some());
        token.assert_async().await;
    }

    /// A refresh posts the `refresh_token` grant with the old token and returns the rotated
    /// token set.
    #[tokio::test]
    async fn refresh_token_roundtrip() {
        let server = MockServer::start_async().await;
        let refresh = server
            .mock_async(|when, then| {
                when.method(POST)
                    .path("/token")
                    .body_contains("grant_type=refresh_token")
                    .body_contains("refresh_token=old-refresh");
                then.status(200)
                    .header("content-type", "application/json")
                    .json_body(json!({
                        "access_token": "token-new",
                        "refresh_token": "refresh-new",
                        "expires_in": 1200,
                        "token_type": "Bearer"
                    }));
            })
            .await;
        let config = config_for(&server);
        let client = OAuthClient::new(config).expect("client");
        let token = client
            .refresh_token("old-refresh")
            .await
            .expect("refresh response");
        assert_eq!(token.access_token, "token-new");
        assert_eq!(token.refresh_token.as_deref(), Some("refresh-new"));
        assert!(token.expires_at.is_some());
        refresh.assert_async().await;
    }
}