use anyhow::{Context, Result}; use sqlx::sqlite::SqliteConnection; use std::collections::HashSet; use std::path::PathBuf; use std::time::{SystemTime, UNIX_EPOCH}; use tokio::fs; pub struct Migration { pub version: i64, pub name: String, pub sql_up: String, #[allow(dead_code)] pub sql_down: String, } #[derive(Clone)] pub struct Migrator { migrations_dir: PathBuf, } impl Migrator { pub fn new(migrations_dir: PathBuf) -> Result { Ok(Migrator { migrations_dir }) } async fn initialize(&self, conn: &mut SqliteConnection) -> Result<()> { let mut tx = sqlx::Connection::begin(conn) .await .context("Failed to start transaction for initialization")?; sqlx::query("CREATE TABLE IF NOT EXISTS migrations (version INTEGER PRIMARY KEY)") .execute(&mut *tx) .await .context("Failed to create migrations table")?; let columns: HashSet = { let rows: Vec<(i32, String, String, i32, Option, i32)> = sqlx::query_as("PRAGMA table_info(migrations)") .fetch_all(&mut *tx) .await .context("Failed to get migrations table info")?; rows.into_iter().map(|row| row.1).collect() }; if !columns.contains("name") { sqlx::query("ALTER TABLE migrations ADD COLUMN name TEXT") .execute(&mut *tx) .await .context("Failed to add 'name' column to migrations table")?; } if !columns.contains("applied_at") { sqlx::query("ALTER TABLE migrations ADD COLUMN applied_at INTEGER") .execute(&mut *tx) .await .context("Failed to add 'applied_at' column to migrations table")?; } tx.commit() .await .context("Failed to commit migrations table initialization")?; Ok(()) } pub async fn load_migrations_async(&self) -> Result> { let mut migrations = Vec::new(); // Use async-aware try_exists if !fs::try_exists(&self.migrations_dir).await? { return Ok(migrations); } let mut entries = fs::read_dir(&self.migrations_dir) .await .context("Failed to read migrations directory")?; while let Some(entry) = entries.next_entry().await? { let path = entry.path(); if path.is_file() && path.extension().unwrap_or_default() == "sql" { let file_name = path.file_stem().unwrap().to_string_lossy(); // Format should be: VERSION_NAME.sql (e.g. 001_create_users.sql if let Some((version_str, name)) = file_name.split_once('_') { if let Ok(version) = version_str.parse::() { let content = fs::read_to_string(&path).await.with_context(|| { format!("Failed to read migration file: {}", path.display()) })?; // Split content into up and down migrations if they exist let parts: Vec<&str> = content.split("-- DOWN").collect(); let sql_up = parts[0].trim().to_string(); let sql_down = parts.get(1).map_or(String::new(), |s| s.trim().to_string()); migrations.push(Migration { version, name: name.to_string(), sql_up, sql_down, }); } } } } migrations.sort_by_key(|m| m.version); Ok(migrations) } pub async fn get_applied_migrations<'a, E>(&self, executor: E) -> Result> where E: sqlx::Executor<'a, Database = sqlx::Sqlite>, { let versions = sqlx::query_as::<_, (i64,)>("SELECT version FROM migrations ORDER BY version") .fetch_all(executor) .await .context("Failed to get applied migrations")? .into_iter() .map(|row| row.0) .collect(); Ok(versions) } pub async fn migrate_up_async(&self, conn: &mut SqliteConnection) -> Result<()> { let migrations = self.load_migrations_async().await?; self.initialize(conn).await?; let applied = self.get_applied_migrations(&mut *conn).await?; let mut tx = sqlx::Connection::begin(conn) .await .context("Failed to start transaction for migrations")?; for migration in migrations { if !applied.contains(&migration.version) { println!( "Applying migration {}: {}", migration.version, migration.name ); sqlx::query(&migration.sql_up) .execute(&mut *tx) .await .with_context(|| { format!( "Failed to apply migration {}: {}", migration.version, migration.name ) })?; let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() as i64; sqlx::query("INSERT INTO migrations (version, name, applied_at) VALUES (?, ?, ?)") .bind(migration.version) .bind(&migration.name.clone()) .bind(now) .execute(&mut *tx) .await .with_context(|| format!("Failed to record migration {}", migration.version))?; } } tx.commit().await.context("Failed to commit migrations")?; Ok(()) } #[allow(dead_code)] pub async fn migrate_down_async( &self, conn: &mut SqliteConnection, target_version: Option, ) -> Result<()> { let migrations = self.load_migrations_async().await?; self.initialize(conn).await?; let applied = self.get_applied_migrations(&mut *conn).await?; // If no target specified, roll back only the latest migration let max_applied = *applied.iter().max().unwrap_or(&0); let target = target_version.unwrap_or(if max_applied > 0 { max_applied - 1 } else { 0 }); let mut tx = sqlx::Connection::begin(conn) .await .context("Failed to start transaction for migrations")?; // Find migrations to roll back (in reverse order) let mut to_rollback: Vec<&Migration> = migrations .iter() .filter(|m| applied.contains(&m.version) && m.version > target) .collect(); to_rollback.sort_by_key(|m| std::cmp::Reverse(m.version)); for migration in to_rollback { println!( "Rolling back migration {}: {}", migration.version, migration.name ); if !migration.sql_down.is_empty() { sqlx::query(&migration.sql_down) .execute(&mut *tx) .await .with_context(|| { format!( "Failed to rollback migration {}: {}", migration.version, migration.name ) })?; } else { println!("Warning: No down migration defined for {}", migration.name); } // Remove the migration record sqlx::query("DELETE FROM migrations WHERE version = ?") .bind(migration.version) .execute(&mut *tx) .await .with_context(|| { format!("Failed to remove migration record {}", migration.version) })?; } tx.commit().await.context("Failed to commit rollback")?; Ok(()) } }