#!/usr/bin/env python3
"""
sbin/setup_mariadb.py
=====================

Interactive setup and migration script: creates all Battery Tracker tables on
MariaDB and populates them from the existing SQLite database.

Run from the repository root:

    python sbin/setup_mariadb.py

Credentials are read from environment variables; missing ones are prompted for:

    MARIADB_HOST      (default: localhost)
    MARIADB_PORT      (default: 3306)
    MARIADB_USER
    MARIADB_PASSWORD
    MARIADB_DATABASE

Or supply a full URL via MARIADB_URL to skip the individual prompts entirely:

    MARIADB_URL='mysql+pymysql://user:pass@host/db?charset=utf8mb4' \\
        python sbin/setup_mariadb.py
"""

import getpass
import os
import shutil
import sqlite3
import sys
from contextlib import closing
from datetime import date
from pathlib import Path
from urllib.parse import quote

# Allow imports from the repo root regardless of where this script is invoked.
REPO_ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO_ROOT))

from sqlalchemy import create_engine, text  # noqa: E402
from sqlalchemy.orm import sessionmaker  # noqa: E402

import config  # noqa: E402
from models import Base, Battery, BatteryPctLog, CapacityTest, ChargeLog, Device  # noqa: E402


# ---------------------------------------------------------------------------
# Credential helpers
# ---------------------------------------------------------------------------

def _env_or_prompt(var: str, label: str, default: str | None = None) -> str:
    """Return env var *var* (stripped) or prompt for it, falling back to *default*.

    Returns an empty string when nothing is entered and there is no default.
    """
    val = os.environ.get(var, "").strip()
    if val:
        return val
    suffix = f" [{default}]" if default else ""
    answer = input(f"  {label}{suffix}: ").strip()
    return answer or (default or "")


def _env_or_prompt_secret(var: str, label: str) -> str:
    """Return env var *var* or prompt for it without echoing the input.

    The typed value is deliberately not stripped: whitespace may be part of
    a password.
    """
    val = os.environ.get(var, "").strip()
    if val:
        return val
    return getpass.getpass(f"  {label}: ")


def collect_credentials() -> str:
    """Return a MariaDB SQLAlchemy URL, prompting for any missing pieces.

    If MARIADB_URL is set it is returned verbatim. Otherwise the URL is built
    from the individual MARIADB_* variables and prompts. The user name and
    password are percent-encoded, so characters such as ``@``, ``:``, ``/``
    or ``%`` in a password cannot corrupt the URL.

    Exits with status 1 if the user or database name is missing, or if the
    port is not a number.
    """
    url = os.environ.get("MARIADB_URL", "").strip()
    if url:
        print("  Using MARIADB_URL from environment.")
        return url

    print("Enter MariaDB connection details (press Enter to accept defaults):\n")
    host = _env_or_prompt("MARIADB_HOST", "Host", "localhost")
    port = _env_or_prompt("MARIADB_PORT", "Port", "3306")
    user = _env_or_prompt("MARIADB_USER", "User")
    password = _env_or_prompt_secret("MARIADB_PASSWORD", "Password")
    database = _env_or_prompt("MARIADB_DATABASE", "Database")

    if not user or not database:
        print("\nERROR: MARIADB_USER and MARIADB_DATABASE are required.")
        sys.exit(1)
    if not port.isdecimal():
        print(f"\nERROR: MARIADB_PORT must be a number, got {port!r}.")
        sys.exit(1)

    # A bare IPv6 literal contains ':' and must be bracketed inside a URL.
    if ":" in host and not host.startswith("["):
        host = f"[{host}]"

    # quote(..., safe="") escapes every reserved character. SQLAlchemy
    # percent-decodes the user and password components when parsing the URL.
    return (
        f"mysql+pymysql://{quote(user, safe='')}:{quote(password, safe='')}"
        f"@{host}:{port}/{database}?charset=utf8mb4"
    )


# ---------------------------------------------------------------------------
# Snapshot
# ---------------------------------------------------------------------------

def snapshot_sqlite() -> None:
    """Copy batteries.db → batteries.db.YYYY-MM-DD.snapshot (in repo root).

    Uses SQLite's online backup API instead of a raw file copy. The snapshot
    is then transactionally consistent even if the app is writing at the same
    time, and it includes committed pages still held in a ``-wal`` file. A
    snapshot taken earlier on the same day is overwritten.
    """
    # NOTE(review): this always snapshots REPO_ROOT/batteries.db, while
    # migrate() reads config.SQLALCHEMY_DATABASE_URI. Confirm they point at
    # the same file.
    src = REPO_ROOT / "batteries.db"
    if not src.exists():
        print("  No batteries.db found — skipping snapshot (fresh install).")
        return

    dst = REPO_ROOT / f"batteries.db.{date.today().isoformat()}.snapshot"
    with closing(sqlite3.connect(src)) as src_conn, closing(sqlite3.connect(dst)) as dst_conn:
        src_conn.backup(dst_conn)
    print(f"  Snapshot written: {dst.name}")


# ---------------------------------------------------------------------------
# Migration
# ---------------------------------------------------------------------------

def migrate(mariadb_url: str) -> None:
    """Copy every Battery Tracker row from SQLite into MariaDB.

    Steps:
      1. Create any missing tables on MariaDB.
      2. Refuse to continue if a destination table already holds rows.
      3. Copy the tables parent-first (devices before batteries before logs)
         in a single transaction, keeping the original primary keys.
      4. Advance each AUTO_INCREMENT counter past the highest migrated id.
      5. Compare per-table row counts between SQLite and MariaDB.

    Exits with status 1 if the pre-check fails or the counts mismatch. Any
    other error rolls back the MariaDB session and is re-raised.
    """
    print("\n=== Battery Tracker: SQLite → MariaDB Migration ===\n")

    # -- Engines --
    sqlite_engine = create_engine(config.SQLALCHEMY_DATABASE_URI)
    mariadb_engine = create_engine(mariadb_url, pool_pre_ping=True)
    src = sessionmaker(bind=sqlite_engine)()
    dst = sessionmaker(bind=mariadb_engine)()
    committed = False

    try:
        # -- Create tables --
        print("Creating tables on MariaDB (if not exist)…")
        Base.metadata.create_all(mariadb_engine)

        # -- Refuse to merge into a non-empty destination --
        # Rows are inserted with their original ids. Leftovers from an
        # earlier run would either collide or skew the row-count check after
        # the data had already been committed.
        occupied = [
            model.__table__.name
            for model in (Device, Battery, CapacityTest, ChargeLog, BatteryPctLog)
            if dst.query(model).count()
        ]
        if occupied:
            print(
                "\nERROR: MariaDB tables already contain rows: "
                f"{', '.join(occupied)}. Empty them before migrating."
            )
            sys.exit(1)

        # -- Read source data --
        src_devices = src.query(Device).all()
        src_batteries = src.query(Battery).all()
        src_cap_tests = src.query(CapacityTest).all()
        src_charge_logs = src.query(ChargeLog).all()
        src_pct_logs = src.query(BatteryPctLog).all()

        print(
            f"Source: {len(src_devices)} devices, {len(src_batteries)} batteries, "
            f"{len(src_cap_tests)} capacity tests, {len(src_charge_logs)} charge logs, "
            f"{len(src_pct_logs)} pct logs\n"
        )

        # -- Devices (no FK dependencies) --
        print("Migrating devices…")
        for d in src_devices:
            dst.add(Device(
                id=d.id,
                name=d.name,
                battery_slots=d.battery_slots,
                device_type=d.device_type,
                notes=d.notes,
                ha_entity_id=d.ha_entity_id,
            ))
        dst.flush()

        # -- Batteries (FK → device) --
        print("Migrating batteries…")
        for b in src_batteries:
            dst.add(Battery(
                id=b.id,
                label=b.label,
                brand=b.brand,
                status=b.status,
                device_id=b.device_id,
                notes=b.notes,
                size=b.size,
                chemistry=b.chemistry,
                capacity_mah=b.capacity_mah,
                tested_capacity_mah=b.tested_capacity_mah,
                tested_date=b.tested_date,
                charge_cycles=b.charge_cycles,
                purchase_date=b.purchase_date,
                storage_location=b.storage_location,
                battery_percentage=b.battery_percentage,
            ))
        dst.flush()

        # -- CapacityTest (FK → battery) --
        print("Migrating capacity tests…")
        for ct in src_cap_tests:
            dst.add(CapacityTest(
                id=ct.id,
                battery_id=ct.battery_id,
                tested_capacity_mah=ct.tested_capacity_mah,
                tested_date=ct.tested_date,
                notes=ct.notes,
            ))
        dst.flush()

        # -- ChargeLog (FK → battery) --
        print("Migrating charge logs…")
        for cl in src_charge_logs:
            dst.add(ChargeLog(
                id=cl.id,
                battery_id=cl.battery_id,
                charged_date=cl.charged_date,
                increment_cycles=cl.increment_cycles,
                notes=cl.notes,
            ))
        dst.flush()

        # -- BatteryPctLog (FK → battery) --
        print("Migrating battery pct logs…")
        for pl in src_pct_logs:
            dst.add(BatteryPctLog(
                id=pl.id,
                battery_id=pl.battery_id,
                percentage=pl.percentage,
                recorded_at=pl.recorded_at,
                source=pl.source,
            ))
        dst.flush()

        dst.commit()
        committed = True
        print("Commit successful.\n")

        # -- Reset AUTO_INCREMENT --
        # Explicit ids were inserted, so the next generated id must continue
        # after the highest migrated one. The value is an int computed here,
        # so it is inlined rather than bound. The table name comes from the
        # model and is quoted by the dialect.
        next_ids = {
            Device: max((d.id for d in src_devices), default=0) + 1,
            Battery: max((b.id for b in src_batteries), default=0) + 1,
            CapacityTest: max((c.id for c in src_cap_tests), default=0) + 1,
            ChargeLog: max((c.id for c in src_charge_logs), default=0) + 1,
            BatteryPctLog: max((p.id for p in src_pct_logs), default=0) + 1,
        }
        preparer = mariadb_engine.dialect.identifier_preparer
        with mariadb_engine.begin() as conn:
            for model, next_id in next_ids.items():
                conn.execute(text(
                    f"ALTER TABLE {preparer.format_table(model.__table__)} "
                    f"AUTO_INCREMENT = {int(next_id)}"
                ))
        print("AUTO_INCREMENT counters reset.\n")

        # -- Verify counts --
        counts = {
            "device": (len(src_devices), dst.query(Device).count()),
            "battery": (len(src_batteries), dst.query(Battery).count()),
            "capacity_test": (len(src_cap_tests), dst.query(CapacityTest).count()),
            "charge_log": (len(src_charge_logs), dst.query(ChargeLog).count()),
            "battery_pct_log": (len(src_pct_logs), dst.query(BatteryPctLog).count()),
        }

        print("=== Verification ===")
        print(f"{'Table':<20} {'SQLite':>8} {'MariaDB':>9} {'OK?':>6}")
        print("-" * 47)
        all_ok = True
        for table, (src_n, dst_n) in counts.items():
            ok = src_n == dst_n
            all_ok = all_ok and ok
            print(f"{table:<20} {src_n:>8} {dst_n:>9} {'OK' if ok else 'MISMATCH':>6}")

        if not all_ok:
            print("\nERROR: Row count mismatch. Do not decommission SQLite.")
            sys.exit(1)

        print("\nMigration complete. All row counts match.")

    except Exception as exc:
        dst.rollback()
        print(f"\nERROR: {exc}")
        if committed:
            # The rollback above cannot undo the earlier commit.
            print(
                "NOTE: the migrated rows were already committed to MariaDB; "
                "empty those tables before re-running."
            )
        raise
    finally:
        src.close()
        dst.close()
        sqlite_engine.dispose()
        mariadb_engine.dispose()


# ---------------------------------------------------------------------------
# Post-migration instructions
# ---------------------------------------------------------------------------

def print_next_steps(mariadb_url: str) -> None:
    """Print the .env and systemd changes needed to point the app at MariaDB."""
    service_file = Path.home() / ".config/systemd/user/battery-tracker.service"
    # systemd expands "%" specifiers in unit files, so a literal "%" (e.g.
    # from a percent-encoded password) must be written as "%%" there.
    systemd_url = mariadb_url.replace("%", "%%")

    print("""
=== Environment configuration ===

Add to your .env file (or export in shell before starting the app):
""")
    print(f"  DATABASE_URL={mariadb_url}")
    print(f"""
=== systemd service file change ===

Edit: {service_file}

In the [Service] section, add (or replace any existing DATABASE_URL line):

  Environment=DATABASE_URL={systemd_url}

Then reload and restart:

  systemctl --user daemon-reload
  systemctl --user restart battery-tracker
""")


# ---------------------------------------------------------------------------
# Entry point
# ---------------------------------------------------------------------------

if __name__ == "__main__":
    print("=== Battery Tracker — MariaDB Setup ===\n")
    mariadb_url = collect_credentials()
    print("\nSnapshotting SQLite database…")
    snapshot_sqlite()
    migrate(mariadb_url)
    print_next_steps(mariadb_url)