Replace migrate_to_mariadb.py with sbin/setup_mariadb.py

Covers all 5 models/tables, prompts for credentials, snapshots SQLite,
and prints env/service config at the end.
This commit is contained in:
2026-04-15 17:34:45 -05:00
parent 8721254476
commit f64e14e713
3 changed files with 324 additions and 145 deletions
+296
View File
@@ -0,0 +1,296 @@
#!/usr/bin/env python3
"""
sbin/setup_mariadb.py
=====================
Interactive setup and migration script: creates all Battery Tracker tables on
MariaDB and populates them from the existing SQLite database.
Run from the repository root:
python sbin/setup_mariadb.py
Credentials are read from environment variables; missing ones are prompted:
MARIADB_HOST (default: localhost)
MARIADB_PORT (default: 3306)
MARIADB_USER
MARIADB_PASSWORD
MARIADB_DATABASE
Or supply a full URL via MARIADB_URL to skip individual prompts entirely:
MARIADB_URL='mysql+pymysql://user:pass@host/db?charset=utf8mb4' \\
python sbin/setup_mariadb.py
"""
import getpass
import os
import shutil
import sys
from datetime import date
from pathlib import Path
# Allow imports from the repo root regardless of where this script is invoked.
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))
from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker
import config
from models import Base, Battery, BatteryPctLog, CapacityTest, ChargeLog, Device
# ---------------------------------------------------------------------------
# Credential helpers
# ---------------------------------------------------------------------------
def _env_or_prompt(var: str, label: str, default: str | None = None) -> str:
val = os.environ.get(var, "").strip()
if val:
return val
suffix = f" [{default}]" if default else ""
answer = input(f" {label}{suffix}: ").strip()
return answer or (default or "")
def _env_or_prompt_secret(var: str, label: str) -> str:
val = os.environ.get(var, "").strip()
if val:
return val
return getpass.getpass(f" {label}: ")
def collect_credentials() -> str:
"""Return a MariaDB SQLAlchemy URL, prompting for any missing pieces."""
url = os.environ.get("MARIADB_URL", "").strip()
if url:
print(f" Using MARIADB_URL from environment.")
return url
print("Enter MariaDB connection details (press Enter to accept defaults):\n")
host = _env_or_prompt("MARIADB_HOST", "Host", "localhost")
port = _env_or_prompt("MARIADB_PORT", "Port", "3306")
user = _env_or_prompt("MARIADB_USER", "User")
password = _env_or_prompt_secret("MARIADB_PASSWORD", "Password")
database = _env_or_prompt("MARIADB_DATABASE", "Database")
if not user or not database:
print("\nERROR: MARIADB_USER and MARIADB_DATABASE are required.")
sys.exit(1)
return f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}?charset=utf8mb4"
# ---------------------------------------------------------------------------
# Snapshot
# ---------------------------------------------------------------------------
def snapshot_sqlite() -> None:
"""Copy batteries.db → batteries.db.YYYY-MM-DD.snapshot (in repo root)."""
src = REPO_ROOT / "batteries.db"
if not src.exists():
print(" No batteries.db found — skipping snapshot (fresh install).")
return
dst = REPO_ROOT / f"batteries.db.{date.today().isoformat()}.snapshot"
shutil.copy2(src, dst)
print(f" Snapshot written: {dst.name}")
# ---------------------------------------------------------------------------
# Migration
# ---------------------------------------------------------------------------
def migrate(mariadb_url: str) -> None:
print("\n=== Battery Tracker: SQLite → MariaDB Migration ===\n")
# -- Engines --
sqlite_engine = create_engine(config.SQLALCHEMY_DATABASE_URI)
mariadb_engine = create_engine(mariadb_url, pool_pre_ping=True)
SrcSession = sessionmaker(bind=sqlite_engine)
DstSession = sessionmaker(bind=mariadb_engine)
src = SrcSession()
dst = DstSession()
try:
# -- Create tables --
print("Creating tables on MariaDB (if not exist)…")
Base.metadata.create_all(mariadb_engine)
# -- Read source data --
src_devices = src.query(Device).all()
src_batteries = src.query(Battery).all()
src_cap_tests = src.query(CapacityTest).all()
src_charge_logs = src.query(ChargeLog).all()
src_pct_logs = src.query(BatteryPctLog).all()
print(
f"Source: {len(src_devices)} devices, {len(src_batteries)} batteries, "
f"{len(src_cap_tests)} capacity tests, {len(src_charge_logs)} charge logs, "
f"{len(src_pct_logs)} pct logs\n"
)
# -- Devices (no FK dependencies) --
print("Migrating devices…")
for d in src_devices:
dst.add(Device(
id=d.id,
name=d.name,
battery_slots=d.battery_slots,
device_type=d.device_type,
notes=d.notes,
ha_entity_id=d.ha_entity_id,
))
dst.flush()
# -- Batteries (FK → device) --
print("Migrating batteries…")
for b in src_batteries:
dst.add(Battery(
id=b.id,
label=b.label,
brand=b.brand,
status=b.status,
device_id=b.device_id,
notes=b.notes,
size=b.size,
chemistry=b.chemistry,
capacity_mah=b.capacity_mah,
tested_capacity_mah=b.tested_capacity_mah,
tested_date=b.tested_date,
charge_cycles=b.charge_cycles,
purchase_date=b.purchase_date,
storage_location=b.storage_location,
battery_percentage=b.battery_percentage,
))
dst.flush()
# -- CapacityTest (FK → battery) --
print("Migrating capacity tests…")
for ct in src_cap_tests:
dst.add(CapacityTest(
id=ct.id,
battery_id=ct.battery_id,
tested_capacity_mah=ct.tested_capacity_mah,
tested_date=ct.tested_date,
notes=ct.notes,
))
dst.flush()
# -- ChargeLog (FK → battery) --
print("Migrating charge logs…")
for cl in src_charge_logs:
dst.add(ChargeLog(
id=cl.id,
battery_id=cl.battery_id,
charged_date=cl.charged_date,
increment_cycles=cl.increment_cycles,
notes=cl.notes,
))
dst.flush()
# -- BatteryPctLog (FK → battery) --
print("Migrating battery pct logs…")
for pl in src_pct_logs:
dst.add(BatteryPctLog(
id=pl.id,
battery_id=pl.battery_id,
percentage=pl.percentage,
recorded_at=pl.recorded_at,
source=pl.source,
))
dst.flush()
dst.commit()
print("Commit successful.\n")
# -- Reset AUTO_INCREMENT --
table_max = {
"device": max((d.id for d in src_devices), default=0),
"battery": max((b.id for b in src_batteries), default=0),
"capacity_test": max((c.id for c in src_cap_tests), default=0),
"charge_log": max((c.id for c in src_charge_logs), default=0),
"battery_pct_log": max((p.id for p in src_pct_logs), default=0),
}
with mariadb_engine.connect() as conn:
for table, max_id in table_max.items():
conn.execute(
text(f"ALTER TABLE {table} AUTO_INCREMENT = :v"),
{"v": max_id + 1},
)
conn.commit()
print("AUTO_INCREMENT counters reset.\n")
# -- Verify counts --
counts = {
"device": (len(src_devices), dst.query(Device).count()),
"battery": (len(src_batteries), dst.query(Battery).count()),
"capacity_test": (len(src_cap_tests), dst.query(CapacityTest).count()),
"charge_log": (len(src_charge_logs), dst.query(ChargeLog).count()),
"battery_pct_log": (len(src_pct_logs), dst.query(BatteryPctLog).count()),
}
print("=== Verification ===")
print(f"{'Table':<20} {'SQLite':>8} {'MariaDB':>9} {'OK?':>6}")
print("-" * 47)
all_ok = True
for table, (src_n, dst_n) in counts.items():
ok = src_n == dst_n
all_ok = all_ok and ok
print(f"{table:<20} {src_n:>8} {dst_n:>9} {'OK' if ok else 'MISMATCH':>6}")
if not all_ok:
print("\nERROR: Row count mismatch. Do not decommission SQLite.")
sys.exit(1)
print("\nMigration complete. All row counts match.")
except Exception as exc:
dst.rollback()
print(f"\nERROR: {exc}")
raise
finally:
src.close()
dst.close()
# ---------------------------------------------------------------------------
# Post-migration instructions
# ---------------------------------------------------------------------------
def print_next_steps(mariadb_url: str) -> None:
service_file = Path.home() / ".config/systemd/user/battery-tracker.service"
print("""
=== Environment configuration ===
Add to your .env file (or export in shell before starting the app):
""")
print(f" DATABASE_URL={mariadb_url}")
print(f"""
=== systemd service file change ===
Edit: {service_file}
In the [Service] section, add (or replace any existing DATABASE_URL line):
Environment=DATABASE_URL={mariadb_url}
Then reload and restart:
systemctl --user daemon-reload
systemctl --user restart battery-tracker
""")
# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
print("=== Battery Tracker — MariaDB Setup ===\n")
mariadb_url = collect_credentials()
print("\nSnapshotting SQLite database…")
snapshot_sqlite()
migrate(mariadb_url)
print_next_steps(mariadb_url)