229 lines
7.9 KiB
Rust
229 lines
7.9 KiB
Rust
use anyhow::{Context, Result};
|
|
use sqlx::sqlite::SqliteConnection;
|
|
use std::collections::HashSet;
|
|
use std::path::PathBuf;
|
|
use std::time::{SystemTime, UNIX_EPOCH};
|
|
use tokio::fs;
|
|
|
|
/// A single schema migration loaded from a `VERSION_NAME.sql` file.
#[derive(Debug, Clone)]
pub struct Migration {
    /// Numeric version parsed from the file-name prefix; migrations are applied in
    /// ascending version order.
    pub version: i64,
    /// Name part of the file name: the text after the first `_`, without `.sql`.
    pub name: String,
    /// SQL executed when applying the migration (file content before the `-- DOWN`
    /// marker, trimmed).
    pub sql_up: String,
    /// SQL executed when rolling the migration back (content following the `-- DOWN`
    /// marker, trimmed); empty when the file has no such marker.
    // Only read by the rollback path, which is itself unused in some builds.
    #[allow(dead_code)]
    pub sql_down: String,
}
|
|
|
|
/// Applies and rolls back SQLite schema migrations stored as `.sql` files in a
/// single directory, tracking applied versions in a `migrations` table.
#[derive(Clone, Debug)]
pub struct Migrator {
    /// Directory scanned for `VERSION_NAME.sql` migration files.
    migrations_dir: PathBuf,
}
|
|
|
|
impl Migrator {
|
|
pub fn new(migrations_dir: PathBuf) -> Result<Self> {
|
|
Ok(Migrator { migrations_dir })
|
|
}
|
|
|
|
async fn initialize(&self, conn: &mut SqliteConnection) -> Result<()> {
|
|
let mut tx = sqlx::Connection::begin(conn)
|
|
.await
|
|
.context("Failed to start transaction for initialization")?;
|
|
|
|
sqlx::query("CREATE TABLE IF NOT EXISTS migrations (version INTEGER PRIMARY KEY)")
|
|
.execute(&mut *tx)
|
|
.await
|
|
.context("Failed to create migrations table")?;
|
|
|
|
let columns: HashSet<String> = {
|
|
let rows: Vec<(i32, String, String, i32, Option<String>, i32)> =
|
|
sqlx::query_as("PRAGMA table_info(migrations)")
|
|
.fetch_all(&mut *tx)
|
|
.await
|
|
.context("Failed to get migrations table info")?;
|
|
rows.into_iter().map(|row| row.1).collect()
|
|
};
|
|
|
|
if !columns.contains("name") {
|
|
sqlx::query("ALTER TABLE migrations ADD COLUMN name TEXT")
|
|
.execute(&mut *tx)
|
|
.await
|
|
.context("Failed to add 'name' column to migrations table")?;
|
|
}
|
|
if !columns.contains("applied_at") {
|
|
sqlx::query("ALTER TABLE migrations ADD COLUMN applied_at INTEGER")
|
|
.execute(&mut *tx)
|
|
.await
|
|
.context("Failed to add 'applied_at' column to migrations table")?;
|
|
}
|
|
|
|
tx.commit()
|
|
.await
|
|
.context("Failed to commit migrations table initialization")?;
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn load_migrations_async(&self) -> Result<Vec<Migration>> {
|
|
let mut migrations = Vec::new();
|
|
|
|
// Use async-aware try_exists
|
|
if !fs::try_exists(&self.migrations_dir).await? {
|
|
return Ok(migrations);
|
|
}
|
|
|
|
let mut entries = fs::read_dir(&self.migrations_dir)
|
|
.await
|
|
.context("Failed to read migrations directory")?;
|
|
|
|
while let Some(entry) = entries.next_entry().await? {
|
|
let path = entry.path();
|
|
|
|
if path.is_file() && path.extension().unwrap_or_default() == "sql" {
|
|
let file_name = path.file_stem().unwrap().to_string_lossy();
|
|
|
|
// Format should be: VERSION_NAME.sql (e.g. 001_create_users.sql
|
|
if let Some((version_str, name)) = file_name.split_once('_') {
|
|
if let Ok(version) = version_str.parse::<i64>() {
|
|
let content = fs::read_to_string(&path).await.with_context(|| {
|
|
format!("Failed to read migration file: {}", path.display())
|
|
})?;
|
|
|
|
// Split content into up and down migrations if they exist
|
|
let parts: Vec<&str> = content.split("-- DOWN").collect();
|
|
let sql_up = parts[0].trim().to_string();
|
|
let sql_down = parts.get(1).map_or(String::new(), |s| s.trim().to_string());
|
|
|
|
migrations.push(Migration {
|
|
version,
|
|
name: name.to_string(),
|
|
sql_up,
|
|
sql_down,
|
|
});
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
migrations.sort_by_key(|m| m.version);
|
|
Ok(migrations)
|
|
}
|
|
|
|
pub async fn get_applied_migrations<'a, E>(&self, executor: E) -> Result<HashSet<i64>>
|
|
where
|
|
E: sqlx::Executor<'a, Database = sqlx::Sqlite>,
|
|
{
|
|
let versions =
|
|
sqlx::query_as::<_, (i64,)>("SELECT version FROM migrations ORDER BY version")
|
|
.fetch_all(executor)
|
|
.await
|
|
.context("Failed to get applied migrations")?
|
|
.into_iter()
|
|
.map(|row| row.0)
|
|
.collect();
|
|
Ok(versions)
|
|
}
|
|
|
|
pub async fn migrate_up_async(&self, conn: &mut SqliteConnection) -> Result<()> {
|
|
let migrations = self.load_migrations_async().await?;
|
|
|
|
self.initialize(conn).await?;
|
|
let applied = self.get_applied_migrations(&mut *conn).await?;
|
|
|
|
let mut tx = sqlx::Connection::begin(conn)
|
|
.await
|
|
.context("Failed to start transaction for migrations")?;
|
|
|
|
for migration in migrations {
|
|
if !applied.contains(&migration.version) {
|
|
println!(
|
|
"Applying migration {}: {}",
|
|
migration.version, migration.name
|
|
);
|
|
|
|
sqlx::query(&migration.sql_up)
|
|
.execute(&mut *tx)
|
|
.await
|
|
.with_context(|| {
|
|
format!(
|
|
"Failed to apply migration {}: {}",
|
|
migration.version, migration.name
|
|
)
|
|
})?;
|
|
|
|
let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() as i64;
|
|
sqlx::query("INSERT INTO migrations (version, name, applied_at) VALUES (?, ?, ?)")
|
|
.bind(migration.version)
|
|
.bind(&migration.name.clone())
|
|
.bind(now)
|
|
.execute(&mut *tx)
|
|
.await
|
|
.with_context(|| format!("Failed to record migration {}", migration.version))?;
|
|
}
|
|
}
|
|
|
|
tx.commit().await.context("Failed to commit migrations")?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[allow(dead_code)]
|
|
pub async fn migrate_down_async(
|
|
&self,
|
|
conn: &mut SqliteConnection,
|
|
target_version: Option<i64>,
|
|
) -> Result<()> {
|
|
let migrations = self.load_migrations_async().await?;
|
|
|
|
self.initialize(conn).await?;
|
|
let applied = self.get_applied_migrations(&mut *conn).await?;
|
|
|
|
// If no target specified, roll back only the latest migration
|
|
let max_applied = *applied.iter().max().unwrap_or(&0);
|
|
let target = target_version.unwrap_or(if max_applied > 0 { max_applied - 1 } else { 0 });
|
|
|
|
let mut tx = sqlx::Connection::begin(conn)
|
|
.await
|
|
.context("Failed to start transaction for migrations")?;
|
|
|
|
// Find migrations to roll back (in reverse order)
|
|
let mut to_rollback: Vec<&Migration> = migrations
|
|
.iter()
|
|
.filter(|m| applied.contains(&m.version) && m.version > target)
|
|
.collect();
|
|
|
|
to_rollback.sort_by_key(|m| std::cmp::Reverse(m.version));
|
|
|
|
for migration in to_rollback {
|
|
println!(
|
|
"Rolling back migration {}: {}",
|
|
migration.version, migration.name
|
|
);
|
|
|
|
if !migration.sql_down.is_empty() {
|
|
sqlx::query(&migration.sql_down)
|
|
.execute(&mut *tx)
|
|
.await
|
|
.with_context(|| {
|
|
format!(
|
|
"Failed to rollback migration {}: {}",
|
|
migration.version, migration.name
|
|
)
|
|
})?;
|
|
} else {
|
|
println!("Warning: No down migration defined for {}", migration.name);
|
|
}
|
|
|
|
// Remove the migration record
|
|
sqlx::query("DELETE FROM migrations WHERE version = ?")
|
|
.bind(migration.version)
|
|
.execute(&mut *tx)
|
|
.await
|
|
.with_context(|| {
|
|
format!("Failed to remove migration record {}", migration.version)
|
|
})?;
|
|
}
|
|
|
|
tx.commit().await.context("Failed to commit rollback")?;
|
|
|
|
Ok(())
|
|
}
|
|
}
|