diff --git a/app.py b/app.py index b648ea0..6d7c7b0 100644 --- a/app.py +++ b/app.py @@ -3,11 +3,22 @@ from sqlalchemy import create_engine, func from sqlalchemy.orm import scoped_session, sessionmaker from datetime import datetime, date, timedelta -import json +import re from models import Base, Battery, BatteryPctLog, Device, CapacityTest, ChargeLog +def _parse_date(val: str) -> str | None: + """Return val if it is a valid YYYY-MM-DD string, else None.""" + if not val: + return None + try: + datetime.strptime(val, "%Y-%m-%d") + return val + except ValueError: + return None + + def create_app(config_object="config"): app = Flask(__name__) app.config.from_object(config_object) @@ -28,6 +39,9 @@ def create_app(config_object="config"): # Home Assistant integration (optional) # ------------------------------------------------------------------ # + from flask_wtf.csrf import CSRFProtect + CSRFProtect(app) + from ha_client import HomeAssistantClient from ha_poller import HaPoller @@ -109,9 +123,14 @@ def create_app(config_object="config"): purchase_date = f.get("purchase_date", "").strip() or None storage_location = f.get("storage_location", "").strip() or None - existing = db.query(func.count(Battery.id)).filter_by(brand=brand).scalar() + existing_labels = [ + r[0] for r in db.query(Battery.label).filter(Battery.brand == brand).all() + ] + nums = [int(m.group(1)) for lbl in existing_labels + if (m := re.search(r'(\d+)$', lbl))] + next_num = max(nums, default=0) for i in range(count): - label = f"{brand} {existing + i + 1:03d}" + label = f"{brand} {next_num + i + 1:03d}" db.add(Battery(label=label, brand=brand, status="available", notes=notes, size=size, chemistry=chemistry, capacity_mah=capacity_mah, purchase_date=purchase_date, storage_location=storage_location)) @@ -160,26 +179,26 @@ def create_app(config_object="config"): .filter_by(battery_id=battery_id) .order_by(BatteryPctLog.recorded_at.desc()) .all()) - charge_logs_json = json.dumps([ + charge_logs_data = [ {"id": l.id, "date": l.charged_date, "cycles": l.increment_cycles, "notes": l.notes or ""} for l in charge_logs - ]) - capacity_tests_json = json.dumps([ + ] + capacity_tests_data = [ {"id": t.id, "date": t.tested_date, "mah": t.tested_capacity_mah, "notes": t.notes or ""} for t in sorted(capacity_tests, key=lambda t: (t.tested_date, t.id), reverse=True) - ]) - pct_logs_json = json.dumps([ + ] + pct_logs_data = [ {"recorded_at": str(l.recorded_at), "pct": l.percentage, "source": l.source or ""} for l in pct_logs - ]) + ] return render_template("battery_detail.html", battery=battery, storage_locations=storage_locations, capacity_tests=capacity_tests, charge_logs=charge_logs, pct_logs=pct_logs, - charge_logs_json=charge_logs_json, - capacity_tests_json=capacity_tests_json, - pct_logs_json=pct_logs_json) + charge_logs_data=charge_logs_data, + capacity_tests_data=capacity_tests_data, + pct_logs_data=pct_logs_data) # ------------------------------------------------------------------ # # Battery — edit notes @@ -204,7 +223,8 @@ def create_app(config_object="config"): battery.chemistry = f.get("chemistry", "").strip() or None battery.capacity_mah = _int("capacity_mah") battery.charge_cycles = _int("charge_cycles") - battery.purchase_date = f.get("purchase_date", "").strip() or None + purchase_raw = f.get("purchase_date", "").strip() + battery.purchase_date = _parse_date(purchase_raw) if purchase_raw else None battery.storage_location = f.get("storage_location", "").strip() or None new_pct = _int("battery_percentage") if new_pct != battery.battery_percentage: @@ -239,10 +259,10 @@ def create_app(config_object="config"): if battery is None: abort(404) mah_raw = request.form.get("tested_capacity_mah", "").strip() - date_val = request.form.get("tested_date", "").strip() + date_val = _parse_date(request.form.get("tested_date", "").strip()) notes = request.form.get("notes", "").strip() or None if not mah_raw or not date_val: - flash("Capacity (mAh) and date are required.", "error") + flash("Capacity (mAh) and a valid date (YYYY-MM-DD) are required.", "error") return redirect(url_for("battery_detail", battery_id=battery_id)) try: mah = int(mah_raw) @@ -281,9 +301,9 @@ def create_app(config_object="config"): battery = db.get(Battery, battery_id) if battery is None: abort(404) - date_val = request.form.get("charged_date", "").strip() + date_val = _parse_date(request.form.get("charged_date", "").strip()) if not date_val: - flash("Date is required.", "error") + flash("A valid date (YYYY-MM-DD) is required.", "error") return redirect(url_for("battery_detail", battery_id=battery_id)) increment = 1 if request.form.get("increment_cycles") else 0 notes = request.form.get("notes", "").strip() or None @@ -544,9 +564,9 @@ def create_app(config_object="config"): label = field_name.replace("_", " ").title() flash(f"Set {label} on {n} batter{'y' if n == 1 else 'ies'}.", "success") elif action == "log_charged": - date_val = request.form.get("charged_date", "").strip() + date_val = _parse_date(request.form.get("charged_date", "").strip()) if not date_val: - flash("Date is required.", "error") + flash("A valid date (YYYY-MM-DD) is required.", "error") return redirect(url_for("dashboard")) increment = 1 if request.form.get("increment_cycles") else 0 for b in batteries: @@ -857,4 +877,4 @@ def create_app(config_object="config"): app = create_app() if __name__ == "__main__": - app.run(debug=True) + app.run(debug=False) diff --git a/config.py b/config.py index 87c5cf9..26b788c 100644 --- a/config.py +++ b/config.py @@ -1,10 +1,19 @@ import os +import logging SQLALCHEMY_DATABASE_URI = os.environ.get( "DATABASE_URL", "sqlite:///batteries.db", ) -SECRET_KEY = os.environ.get("SECRET_KEY", "dev-secret-change-in-prod") + +_secret_key = os.environ.get("SECRET_KEY") +if not _secret_key: + logging.warning( + "SECRET_KEY not set — using insecure default. " + "Set SECRET_KEY env var before running in production." + ) +SECRET_KEY = _secret_key or "dev-secret-change-in-prod" + SQLALCHEMY_TRACK_MODIFICATIONS = False # Home Assistant integration (all optional — app works normally when absent) diff --git a/requirements.txt b/requirements.txt index b866a6e..d3e8069 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ Flask>=3.0,<4.0 +Flask-WTF>=1.2,<2.0 SQLAlchemy>=2.0,<3.0 PyMySQL>=1.1,<2.0 waitress>=3.0,<4.0 diff --git a/templates/base.html b/templates/base.html index 07bec87..68f3df1 100644 --- a/templates/base.html +++ b/templates/base.html @@ -349,6 +349,18 @@ navigator.serviceWorker.register('/sw.js'); } + // Inject CSRF token into all POST forms + document.addEventListener('DOMContentLoaded', function() { + var token = '{{ csrf_token() }}'; + document.querySelectorAll('form').forEach(function(form) { + if (form.method.toLowerCase() === 'post') { + var inp = document.createElement('input'); + inp.type = 'hidden'; inp.name = 'csrf_token'; inp.value = token; + form.appendChild(inp); + } + }); + }); + (function() { var modal = document.getElementById('confirm-modal'); var msgEl = document.getElementById('confirm-modal-msg'); diff --git a/templates/battery_detail.html b/templates/battery_detail.html index c7bfef3..e4abf9f 100644 --- a/templates/battery_detail.html +++ b/templates/battery_detail.html @@ -411,7 +411,7 @@ function metaSelectChanged(sel, inputId) { (function() { var canvas = document.getElementById('pct-chart'); if (!canvas) return; - var rawLogs = {{ pct_logs_json | safe }}; + var rawLogs = {{ pct_logs_data | tojson }}; // pct_logs_json is ordered newest-first; chart wants oldest-first var logsAsc = rawLogs.slice().reverse(); var vals = logsAsc.map(function(l) { return l.pct; }); @@ -472,14 +472,15 @@ function metaSelectChanged(sel, inputId) { @@ -549,6 +550,14 @@ function metaSelectChanged(sel, inputId) { var HIST_PAGE = 20; var _batteryId = {{ battery.id }}; +function escHtml(s) { + return String(s == null ? '' : s) + .replace(/&/g, '&') + .replace(//g, '>') + .replace(/"/g, '"'); +} + function openHistModal(id) { document.getElementById(id).classList.add('open'); } @@ -618,7 +627,7 @@ function makeModal(cfg) { } // ── Capacity modal ──────────────────────────────────────────────────────── -var _capAll = {{ capacity_tests_json | safe }}; +var _capAll = {{ capacity_tests_data | tojson }}; var capModal = makeModal({ all: _capAll, bodyId: 'cap-modal-body', prevId: 'cap-prev', nextId: 'cap-next', pageInfoId: 'cap-page-info', @@ -626,9 +635,9 @@ var capModal = makeModal({ thead: 'DateCapacityNotes', renderRow: function(r) { return '' + - '' + r.date + '' + + '' + escHtml(r.date) + '' + '' + r.mah + ' mAh' + - '' + (r.notes || '—') + '' + + '' + (escHtml(r.notes) || '—') + '' + '' + '
Date+CycleNotes', renderRow: function(r) { return '' + - '' + r.date + '' + + '' + escHtml(r.date) + '' + '' + (r.cycles ? '✓' : '—') + '' + - '' + (r.notes || '—') + '' + + '' + (escHtml(r.notes) || '—') + '' + '' + '⚠ ' + r.pct + '%' : r.pct + '%'; return '' + - '' + r.recorded_at + '' + + '' + escHtml(r.recorded_at) + '' + '' + pctHtml + '' + - '' + (r.source || '—') + '' + + '' + (escHtml(r.source) || '—') + '' + ''; } }); diff --git a/tests/conftest.py b/tests/conftest.py index 3487106..5567822 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,6 +15,7 @@ def app(): SECRET_KEY = "test-secret" SQLALCHEMY_TRACK_MODIFICATIONS = False TESTING = True + WTF_CSRF_ENABLED = False # disable CSRF validation in tests HOMEASSISTANT_URL = None # prevent HA poller from starting in tests HOMEASSISTANT_API_KEY = None diff --git a/tests/test_ha_integration.py b/tests/test_ha_integration.py index 7586300..2c639f2 100644 --- a/tests/test_ha_integration.py +++ b/tests/test_ha_integration.py @@ -33,6 +33,7 @@ def ha_app(): SECRET_KEY = "test-secret" SQLALCHEMY_TRACK_MODIFICATIONS = False TESTING = True + WTF_CSRF_ENABLED = False # disable CSRF validation in tests HOMEASSISTANT_URL = "http://ha.test:8123" HOMEASSISTANT_API_KEY = "fake-token" HOMEASSISTANT_POLL_INTERVAL = 300