use std::time::Duration as StdDuration;

use chrono::{DateTime, Duration, Utc};
use reqwest::Client;
use serde::{Deserialize, Serialize};

use crate::{Error, Result, config::McpOAuthConfig};

/// OAuth token set persisted on behalf of MCP servers and providers.
///
/// Every field other than `access_token` may be missing from the serialized
/// form, in which case it deserializes as `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Default)]
pub struct OAuthToken {
    /// Bearer credential issued by the authorization server.
    pub access_token: String,
    /// Refresh credential, present only when the provider issues one.
    #[serde(default)]
    pub refresh_token: Option<String>,
    /// Absolute UTC instant at which `access_token` stops being valid; `None` when unknown.
    #[serde(default)]
    pub expires_at: Option<DateTime<Utc>>,
    /// Space-delimited scope string as reported by the provider.
    #[serde(default)]
    pub scope: Option<String>,
    /// Token type reported by the provider (usually `Bearer`).
    #[serde(default)]
    pub token_type: Option<String>,
}

impl OAuthToken {
    /// Reports whether the access token is expired at `now`.
    ///
    /// A token whose expiry is unknown (`expires_at == None`) is never reported as expired.
    pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
        self.expires_at.is_some_and(|deadline| deadline <= now)
    }

    /// Reports whether the access token expires no later than `window` after `now`.
    ///
    /// For a non-negative `window`, an already-expired token also satisfies this
    /// check. A token whose expiry is unknown never does.
    pub fn will_expire_within(&self, window: Duration, now: DateTime<Utc>) -> bool {
        self.expires_at
            .is_some_and(|deadline| deadline - now <= window)
    }
}

/// Device-authorization session handed back by the authorization server.
#[derive(Debug, Clone)]
pub struct DeviceAuthorization {
    /// Code the client presents to the token endpoint while polling.
    pub device_code: String,
    /// Code shown to the user for entry at the verification URI.
    pub user_code: String,
    /// Page where the user enters `user_code`.
    pub verification_uri: String,
    /// Optional variant of the verification URI that already embeds the user code.
    pub verification_uri_complete: Option<String>,
    /// Absolute UTC instant after which this session can no longer be used.
    pub expires_at: DateTime<Utc>,
    /// Delay to wait between token-endpoint polls.
    pub interval: StdDuration,
    /// Optional human-readable instructions supplied by the server.
    pub message: Option<String>,
}

impl DeviceAuthorization {
    /// Reports whether the session has expired at `now`.
    pub fn is_expired(&self, now: DateTime<Utc>) -> bool {
        self.expires_at <= now
    }
}

/// Result of polling the token endpoint during a device-authorization flow.
#[derive(Debug, Clone)] pub enum DevicePollState { Pending { retry_in: StdDuration }, Complete(OAuthToken), } pub struct OAuthClient { http: Client, config: McpOAuthConfig, } impl OAuthClient { pub fn new(config: McpOAuthConfig) -> Result { let http = Client::builder() .user_agent("OwlenOAuth/1.0") .build() .map_err(|err| Error::Network(format!("Failed to construct HTTP client: {err}")))?; Ok(Self { http, config }) } fn scope_value(&self) -> Option { if self.config.scopes.is_empty() { None } else { Some(self.config.scopes.join(" ")) } } fn token_request_base(&self) -> Vec<(String, String)> { let mut params = vec![("client_id".to_string(), self.config.client_id.clone())]; if let Some(secret) = &self.config.client_secret { params.push(("client_secret".to_string(), secret.clone())); } params } pub async fn start_device_authorization(&self) -> Result { let device_url = self .config .device_authorization_url .as_ref() .ok_or_else(|| { Error::Config("Device authorization endpoint is not configured.".to_string()) })?; let mut params = self.token_request_base(); if let Some(scope) = self.scope_value() { params.push(("scope".to_string(), scope)); } let response = self .http .post(device_url) .form(¶ms) .send() .await .map_err(|err| map_http_error("start device authorization", err))?; let status = response.status(); let payload = response .json::() .await .map_err(|err| { Error::Auth(format!( "Failed to parse device authorization response (status {status}): {err}" )) })?; let expires_at = Utc::now() + Duration::seconds(payload.expires_in.min(i64::MAX as u64) as i64); let interval = StdDuration::from_secs(payload.interval.unwrap_or(5).max(1)); Ok(DeviceAuthorization { device_code: payload.device_code, user_code: payload.user_code, verification_uri: payload.verification_uri, verification_uri_complete: payload.verification_uri_complete, expires_at, interval, message: payload.message, }) } pub async fn poll_device_token(&self, auth: &DeviceAuthorization) -> Result { let mut 
params = self.token_request_base(); params.push(("grant_type".to_string(), DEVICE_CODE_GRANT.to_string())); params.push(("device_code".to_string(), auth.device_code.clone())); if let Some(scope) = self.scope_value() { params.push(("scope".to_string(), scope)); } let response = self .http .post(&self.config.token_url) .form(¶ms) .send() .await .map_err(|err| map_http_error("poll device token", err))?; let status = response.status(); let text = response .text() .await .map_err(|err| map_http_error("read token response", err))?; if status.is_success() { let payload: TokenResponse = serde_json::from_str(&text).map_err(|err| { Error::Auth(format!( "Failed to parse OAuth token response: {err}; body: {text}" )) })?; return Ok(DevicePollState::Complete(oauth_token_from_response( payload, ))); } let error = serde_json::from_str::(&text).unwrap_or_else(|_| { OAuthErrorResponse { error: "unknown_error".to_string(), error_description: Some(text.clone()), } }); match error.error.as_str() { "authorization_pending" => Ok(DevicePollState::Pending { retry_in: auth.interval, }), "slow_down" => Ok(DevicePollState::Pending { retry_in: auth.interval.saturating_add(StdDuration::from_secs(5)), }), "access_denied" => { Err(Error::Auth(error.error_description.unwrap_or_else(|| { "User declined authorization".to_string() }))) } "expired_token" | "expired_device_code" => { Err(Error::Auth(error.error_description.unwrap_or_else(|| { "Device authorization expired".to_string() }))) } other => Err(Error::Auth( error .error_description .unwrap_or_else(|| format!("OAuth error: {other}")), )), } } pub async fn refresh_token(&self, refresh_token: &str) -> Result { let mut params = self.token_request_base(); params.push(("grant_type".to_string(), "refresh_token".to_string())); params.push(("refresh_token".to_string(), refresh_token.to_string())); if let Some(scope) = self.scope_value() { params.push(("scope".to_string(), scope)); } let response = self .http .post(&self.config.token_url) .form(¶ms) 
.send() .await .map_err(|err| map_http_error("refresh OAuth token", err))?; let status = response.status(); let text = response .text() .await .map_err(|err| map_http_error("read refresh response", err))?; if status.is_success() { let payload: TokenResponse = serde_json::from_str(&text).map_err(|err| { Error::Auth(format!( "Failed to parse OAuth refresh response: {err}; body: {text}" )) })?; Ok(oauth_token_from_response(payload)) } else { let error = serde_json::from_str::(&text).unwrap_or_else(|_| { OAuthErrorResponse { error: "unknown_error".to_string(), error_description: Some(text.clone()), } }); Err(Error::Auth(error.error_description.unwrap_or_else(|| { format!("OAuth token refresh failed: {}", error.error) }))) } } } const DEVICE_CODE_GRANT: &str = "urn:ietf:params:oauth:grant-type:device_code"; #[derive(Debug, Deserialize)] struct DeviceAuthorizationResponse { device_code: String, user_code: String, verification_uri: String, #[serde(default)] verification_uri_complete: Option, expires_in: u64, #[serde(default)] interval: Option, #[serde(default)] message: Option, } #[derive(Debug, Deserialize)] struct TokenResponse { access_token: String, #[serde(default)] refresh_token: Option, #[serde(default)] expires_in: Option, #[serde(default)] scope: Option, #[serde(default)] token_type: Option, } #[derive(Debug, Deserialize)] struct OAuthErrorResponse { error: String, #[serde(default)] error_description: Option, } fn oauth_token_from_response(payload: TokenResponse) -> OAuthToken { let expires_at = payload .expires_in .map(|seconds| seconds.min(i64::MAX as u64) as i64) .map(|seconds| Utc::now() + Duration::seconds(seconds)); OAuthToken { access_token: payload.access_token, refresh_token: payload.refresh_token, expires_at, scope: payload.scope, token_type: payload.token_type, } } fn map_http_error(action: &str, err: reqwest::Error) -> Error { if err.is_timeout() { Error::Timeout(format!("OAuth {action} request timed out: {err}")) } else if err.is_connect() { 
Error::Network(format!("OAuth {action} connection error: {err}")) } else { Error::Network(format!("OAuth {action} request failed: {err}")) } } #[cfg(test)] mod tests { use super::*; use httpmock::prelude::*; use serde_json::json; fn config_for(server: &MockServer) -> McpOAuthConfig { McpOAuthConfig { client_id: "test-client".to_string(), client_secret: None, authorize_url: server.url("/authorize"), token_url: server.url("/token"), device_authorization_url: Some(server.url("/device")), redirect_url: None, scopes: vec!["repo".to_string(), "user".to_string()], token_env: None, header: None, header_prefix: None, } } fn sample_device_authorization() -> DeviceAuthorization { DeviceAuthorization { device_code: "device-123".to_string(), user_code: "ABCD-EFGH".to_string(), verification_uri: "https://example.test/activate".to_string(), verification_uri_complete: Some( "https://example.test/activate?user_code=ABCD-EFGH".to_string(), ), expires_at: Utc::now() + Duration::minutes(10), interval: StdDuration::from_secs(5), message: Some("Open the verification URL and enter the code.".to_string()), } } #[tokio::test] async fn start_device_authorization_returns_payload() { let server = MockServer::start_async().await; let device_mock = server .mock_async(|when, then| { when.method(POST).path("/device"); then.status(200) .header("content-type", "application/json") .json_body(json!({ "device_code": "device-123", "user_code": "ABCD-EFGH", "verification_uri": "https://example.test/activate", "verification_uri_complete": "https://example.test/activate?user_code=ABCD-EFGH", "expires_in": 600, "interval": 7, "message": "Open the verification URL and enter the code." 
})); }) .await; let client = OAuthClient::new(config_for(&server)).expect("client"); let auth = client .start_device_authorization() .await .expect("device authorization payload"); assert_eq!(auth.user_code, "ABCD-EFGH"); assert_eq!(auth.interval, StdDuration::from_secs(7)); assert!(auth.expires_at > Utc::now()); device_mock.assert_async().await; } #[tokio::test] async fn poll_device_token_reports_pending() { let server = MockServer::start_async().await; let pending = server .mock_async(|when, then| { when.method(POST) .path("/token") .body_contains( "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code", ) .body_contains("device_code=device-123"); then.status(400) .header("content-type", "application/json") .json_body(json!({ "error": "authorization_pending" })); }) .await; let config = config_for(&server); let client = OAuthClient::new(config).expect("client"); let auth = sample_device_authorization(); let result = client.poll_device_token(&auth).await.expect("poll result"); match result { DevicePollState::Pending { retry_in } => { assert_eq!(retry_in, StdDuration::from_secs(5)); } other => panic!("expected pending state, got {other:?}"), } pending.assert_async().await; } #[tokio::test] async fn poll_device_token_applies_slow_down_backoff() { let server = MockServer::start_async().await; let slow = server .mock_async(|when, then| { when.method(POST).path("/token"); then.status(400) .header("content-type", "application/json") .json_body(json!({ "error": "slow_down" })); }) .await; let config = config_for(&server); let client = OAuthClient::new(config).expect("client"); let auth = sample_device_authorization(); let result = client.poll_device_token(&auth).await.expect("poll result"); match result { DevicePollState::Pending { retry_in } => { assert_eq!(retry_in, StdDuration::from_secs(10)); } other => panic!("expected pending state, got {other:?}"), } slow.assert_async().await; } #[tokio::test] async fn poll_device_token_returns_token_when_authorized() 
{ let server = MockServer::start_async().await; let token = server .mock_async(|when, then| { when.method(POST).path("/token"); then.status(200) .header("content-type", "application/json") .json_body(json!({ "access_token": "token-abc", "refresh_token": "refresh-xyz", "expires_in": 3600, "token_type": "Bearer", "scope": "repo user" })); }) .await; let config = config_for(&server); let client = OAuthClient::new(config).expect("client"); let auth = sample_device_authorization(); let result = client.poll_device_token(&auth).await.expect("poll result"); let token_info = match result { DevicePollState::Complete(token) => token, other => panic!("expected completion, got {other:?}"), }; assert_eq!(token_info.access_token, "token-abc"); assert_eq!(token_info.refresh_token.as_deref(), Some("refresh-xyz")); assert!(token_info.expires_at.is_some()); token.assert_async().await; } #[tokio::test] async fn refresh_token_roundtrip() { let server = MockServer::start_async().await; let refresh = server .mock_async(|when, then| { when.method(POST) .path("/token") .body_contains("grant_type=refresh_token") .body_contains("refresh_token=old-refresh"); then.status(200) .header("content-type", "application/json") .json_body(json!({ "access_token": "token-new", "refresh_token": "refresh-new", "expires_in": 1200, "token_type": "Bearer" })); }) .await; let config = config_for(&server); let client = OAuthClient::new(config).expect("client"); let token = client .refresh_token("old-refresh") .await .expect("refresh response"); assert_eq!(token.access_token, "token-new"); assert_eq!(token.refresh_token.as_deref(), Some("refresh-new")); assert!(token.expires_at.is_some()); refresh.assert_async().await; } }