Files
owly-news/backend-rust/src/migrations.rs

229 lines
7.9 KiB
Rust

use anyhow::{Context, Result};
use sqlx::sqlite::SqliteConnection;
use std::collections::HashSet;
use std::path::PathBuf;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::fs;
/// A single schema migration loaded from a `VERSION_NAME.sql` file.
#[derive(Debug, Clone)]
pub struct Migration {
    /// Numeric version parsed from the file-name prefix; migrations are applied in ascending order.
    pub version: i64,
    /// Descriptive name: the part of the file stem after the first `_`.
    pub name: String,
    /// SQL executed when applying the migration (the text before the `-- DOWN` marker).
    pub sql_up: String,
    /// SQL executed when rolling back (the text after the `-- DOWN` marker); empty if absent.
    // allow(dead_code): only read by `Migrator::migrate_down_async`, which is itself
    // marked `allow(dead_code)`, so the compiler would otherwise report this field as unread.
    #[allow(dead_code)]
    pub sql_down: String,
}
/// Applies and rolls back SQLite schema migrations stored as `.sql` files in a directory.
#[derive(Debug, Clone)]
pub struct Migrator {
    /// Directory scanned for `VERSION_NAME.sql` migration files.
    migrations_dir: PathBuf,
}
impl Migrator {
pub fn new(migrations_dir: PathBuf) -> Result<Self> {
Ok(Migrator { migrations_dir })
}
async fn initialize(&self, conn: &mut SqliteConnection) -> Result<()> {
let mut tx = sqlx::Connection::begin(conn)
.await
.context("Failed to start transaction for initialization")?;
sqlx::query("CREATE TABLE IF NOT EXISTS migrations (version INTEGER PRIMARY KEY)")
.execute(&mut *tx)
.await
.context("Failed to create migrations table")?;
let columns: HashSet<String> = {
let rows: Vec<(i32, String, String, i32, Option<String>, i32)> =
sqlx::query_as("PRAGMA table_info(migrations)")
.fetch_all(&mut *tx)
.await
.context("Failed to get migrations table info")?;
rows.into_iter().map(|row| row.1).collect()
};
if !columns.contains("name") {
sqlx::query("ALTER TABLE migrations ADD COLUMN name TEXT")
.execute(&mut *tx)
.await
.context("Failed to add 'name' column to migrations table")?;
}
if !columns.contains("applied_at") {
sqlx::query("ALTER TABLE migrations ADD COLUMN applied_at INTEGER")
.execute(&mut *tx)
.await
.context("Failed to add 'applied_at' column to migrations table")?;
}
tx.commit()
.await
.context("Failed to commit migrations table initialization")?;
Ok(())
}
pub async fn load_migrations_async(&self) -> Result<Vec<Migration>> {
let mut migrations = Vec::new();
// Use async-aware try_exists
if !fs::try_exists(&self.migrations_dir).await? {
return Ok(migrations);
}
let mut entries = fs::read_dir(&self.migrations_dir)
.await
.context("Failed to read migrations directory")?;
while let Some(entry) = entries.next_entry().await? {
let path = entry.path();
if path.is_file() && path.extension().unwrap_or_default() == "sql" {
let file_name = path.file_stem().unwrap().to_string_lossy();
// Format should be: VERSION_NAME.sql (e.g. 001_create_users.sql
if let Some((version_str, name)) = file_name.split_once('_') {
if let Ok(version) = version_str.parse::<i64>() {
let content = fs::read_to_string(&path).await.with_context(|| {
format!("Failed to read migration file: {}", path.display())
})?;
// Split content into up and down migrations if they exist
let parts: Vec<&str> = content.split("-- DOWN").collect();
let sql_up = parts[0].trim().to_string();
let sql_down = parts.get(1).map_or(String::new(), |s| s.trim().to_string());
migrations.push(Migration {
version,
name: name.to_string(),
sql_up,
sql_down,
});
}
}
}
}
migrations.sort_by_key(|m| m.version);
Ok(migrations)
}
pub async fn get_applied_migrations<'a, E>(&self, executor: E) -> Result<HashSet<i64>>
where
E: sqlx::Executor<'a, Database = sqlx::Sqlite>,
{
let versions =
sqlx::query_as::<_, (i64,)>("SELECT version FROM migrations ORDER BY version")
.fetch_all(executor)
.await
.context("Failed to get applied migrations")?
.into_iter()
.map(|row| row.0)
.collect();
Ok(versions)
}
pub async fn migrate_up_async(&self, conn: &mut SqliteConnection) -> Result<()> {
let migrations = self.load_migrations_async().await?;
self.initialize(conn).await?;
let applied = self.get_applied_migrations(&mut *conn).await?;
let mut tx = sqlx::Connection::begin(conn)
.await
.context("Failed to start transaction for migrations")?;
for migration in migrations {
if !applied.contains(&migration.version) {
println!(
"Applying migration {}: {}",
migration.version, migration.name
);
sqlx::query(&migration.sql_up)
.execute(&mut *tx)
.await
.with_context(|| {
format!(
"Failed to apply migration {}: {}",
migration.version, migration.name
)
})?;
let now = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs() as i64;
sqlx::query("INSERT INTO migrations (version, name, applied_at) VALUES (?, ?, ?)")
.bind(migration.version)
.bind(&migration.name.clone())
.bind(now)
.execute(&mut *tx)
.await
.with_context(|| format!("Failed to record migration {}", migration.version))?;
}
}
tx.commit().await.context("Failed to commit migrations")?;
Ok(())
}
#[allow(dead_code)]
pub async fn migrate_down_async(
&self,
conn: &mut SqliteConnection,
target_version: Option<i64>,
) -> Result<()> {
let migrations = self.load_migrations_async().await?;
self.initialize(conn).await?;
let applied = self.get_applied_migrations(&mut *conn).await?;
// If no target specified, roll back only the latest migration
let max_applied = *applied.iter().max().unwrap_or(&0);
let target = target_version.unwrap_or(if max_applied > 0 { max_applied - 1 } else { 0 });
let mut tx = sqlx::Connection::begin(conn)
.await
.context("Failed to start transaction for migrations")?;
// Find migrations to roll back (in reverse order)
let mut to_rollback: Vec<&Migration> = migrations
.iter()
.filter(|m| applied.contains(&m.version) && m.version > target)
.collect();
to_rollback.sort_by_key(|m| std::cmp::Reverse(m.version));
for migration in to_rollback {
println!(
"Rolling back migration {}: {}",
migration.version, migration.name
);
if !migration.sql_down.is_empty() {
sqlx::query(&migration.sql_down)
.execute(&mut *tx)
.await
.with_context(|| {
format!(
"Failed to rollback migration {}: {}",
migration.version, migration.name
)
})?;
} else {
println!("Warning: No down migration defined for {}", migration.name);
}
// Remove the migration record
sqlx::query("DELETE FROM migrations WHERE version = ?")
.bind(migration.version)
.execute(&mut *tx)
.await
.with_context(|| {
format!("Failed to remove migration record {}", migration.version)
})?;
}
tx.commit().await.context("Failed to commit rollback")?;
Ok(())
}
}