- Introduce `MockProvider` with configurable models, health status, generation handlers, and error simulation. - Add common test utilities and integration tests covering provider registration, model aggregation, request routing, error handling, and health refresh.
107 lines
3.0 KiB
Rust
107 lines
3.0 KiB
Rust
use std::sync::Arc;
|
|
|
|
use async_trait::async_trait;
|
|
use futures::stream::{self, StreamExt};
|
|
use owlen_core::Result as CoreResult;
|
|
use owlen_core::provider::{
|
|
GenerateChunk, GenerateRequest, GenerateStream, ModelInfo, ModelProvider, ProviderMetadata,
|
|
ProviderStatus, ProviderType,
|
|
};
|
|
|
|
pub struct MockProvider {
|
|
metadata: ProviderMetadata,
|
|
models: Vec<ModelInfo>,
|
|
status: ProviderStatus,
|
|
#[allow(clippy::type_complexity)]
|
|
generate_handler: Option<Arc<dyn Fn(GenerateRequest) -> Vec<GenerateChunk> + Send + Sync>>,
|
|
generate_error: Option<Arc<dyn Fn() -> owlen_core::Error + Send + Sync>>,
|
|
}
|
|
|
|
impl MockProvider {
|
|
pub fn new(id: &str) -> Self {
|
|
let metadata = ProviderMetadata::new(
|
|
id,
|
|
format!("Mock Provider ({})", id),
|
|
ProviderType::Local,
|
|
false,
|
|
);
|
|
|
|
Self {
|
|
metadata,
|
|
models: vec![ModelInfo {
|
|
name: format!("{}-primary", id),
|
|
size_bytes: None,
|
|
capabilities: vec!["chat".into()],
|
|
description: Some("Mock model".into()),
|
|
provider: ProviderMetadata::new(id, "Mock", ProviderType::Local, false),
|
|
metadata: Default::default(),
|
|
}],
|
|
status: ProviderStatus::Available,
|
|
generate_handler: None,
|
|
generate_error: None,
|
|
}
|
|
}
|
|
|
|
pub fn with_models(mut self, models: Vec<ModelInfo>) -> Self {
|
|
self.models = models;
|
|
self
|
|
}
|
|
|
|
pub fn with_status(mut self, status: ProviderStatus) -> Self {
|
|
self.status = status;
|
|
self
|
|
}
|
|
|
|
pub fn with_generate_handler<F>(mut self, handler: F) -> Self
|
|
where
|
|
F: Fn(GenerateRequest) -> Vec<GenerateChunk> + Send + Sync + 'static,
|
|
{
|
|
self.generate_handler = Some(Arc::new(handler));
|
|
self
|
|
}
|
|
|
|
pub fn with_generate_error<F>(mut self, factory: F) -> Self
|
|
where
|
|
F: Fn() -> owlen_core::Error + Send + Sync + 'static,
|
|
{
|
|
self.generate_error = Some(Arc::new(factory));
|
|
self
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl ModelProvider for MockProvider {
|
|
fn metadata(&self) -> &ProviderMetadata {
|
|
&self.metadata
|
|
}
|
|
|
|
async fn health_check(&self) -> CoreResult<ProviderStatus> {
|
|
Ok(self.status)
|
|
}
|
|
|
|
async fn list_models(&self) -> CoreResult<Vec<ModelInfo>> {
|
|
Ok(self.models.clone())
|
|
}
|
|
|
|
async fn generate_stream(&self, request: GenerateRequest) -> CoreResult<GenerateStream> {
|
|
if let Some(factory) = &self.generate_error {
|
|
return Err(factory());
|
|
}
|
|
|
|
let chunks = if let Some(handler) = &self.generate_handler {
|
|
(handler)(request)
|
|
} else {
|
|
vec![GenerateChunk::final_chunk()]
|
|
};
|
|
|
|
let stream = stream::iter(chunks.into_iter().map(Ok)).boxed();
|
|
Ok(Box::pin(stream))
|
|
}
|
|
}
|
|
|
|
/// Lets a [`MockProvider`] be converted with `.into()` wherever a shared
/// `Arc<dyn ModelProvider>` is expected.
impl From<MockProvider> for Arc<dyn ModelProvider> {
    /// Moves `provider` into a new `Arc` and erases its concrete type.
    fn from(provider: MockProvider) -> Arc<dyn ModelProvider> {
        Arc::new(provider) as Arc<dyn ModelProvider>
    }
}
|