Prechádzať zdrojové kódy

Add database migration functionality and initialize migration checks in TradingStats

- Introduced a new migration module to handle database schema updates, including adding new columns for P&L and risk metrics.
- Integrated migration checks in the TradingStats class to ensure the database schema is up-to-date before establishing a connection.
- Created a metadata table to track schema versions and ensure idempotent migrations.
- Enhanced logging for better traceability during migration processes.
Carles Sentis 3 dní pred
rodič
commit
54ad70c59d

+ 1 - 0
src/migrations/__init__.py

@@ -0,0 +1 @@
+ 

+ 120 - 0
src/migrations/migrate_db.py

@@ -0,0 +1,120 @@
+import sqlite3
+import os
+import logging
+
+# Configure logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
+
+# Determine the absolute path to the project root directory
+# Adjusted for the script being in src/migrations/
+PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+DB_PATH = os.path.join(PROJECT_ROOT, "data", "trading_stats.sqlite")
+
+# MIGRATION_SET_VERSION identifies this specific group of schema changes.
+# We are bringing the schema to a state that includes the new P&L/risk columns.
+MIGRATION_SET_VERSION = 1
+
+def _get_db_connection(db_path):
+    """Gets a database connection."""
+    return sqlite3.connect(db_path)
+
+def _get_current_schema_version(conn):
+    """Retrieves the current schema version from the metadata table."""
+    try:
+        cursor = conn.cursor()
+        # Ensure metadata table exists (idempotent)
+        cursor.execute("CREATE TABLE IF NOT EXISTS metadata (key TEXT PRIMARY KEY, value TEXT)")
+        conn.commit()
+        
+        cursor.execute("SELECT value FROM metadata WHERE key = 'schema_version'")
+        row = cursor.fetchone()
+        return int(row[0]) if row else 0  # Default to 0 if no version is set
+    except sqlite3.Error as e:
+        logger.error(f"Error getting schema version: {e}. Assuming version 0.")
+        return 0
+
+def _set_schema_version(conn, version):
+    """Sets the schema version in the metadata table."""
+    try:
+        cursor = conn.cursor()
+        cursor.execute("INSERT OR REPLACE INTO metadata (key, value) VALUES ('schema_version', ?)", (str(version),))
+        conn.commit()
+        logger.info(f"Database schema version successfully set to {version}.")
+    except sqlite3.Error as e:
+        logger.error(f"Error setting schema version to {version}: {e}")
+
+def _column_exists(conn, table_name, column_name):
+    """Checks if a column exists in a table using PRAGMA table_info."""
+    cursor = conn.cursor()
+    cursor.execute(f"PRAGMA table_info({table_name})")
+    columns = [row[1] for row in cursor.fetchall()]
+    return column_name in columns
+
+def _add_column_if_not_exists(conn, table_name, column_name, column_definition):
+    """Adds a column to a table if it doesn't already exist."""
+    if not _column_exists(conn, table_name, column_name):
+        try:
+            cursor = conn.cursor()
+            query = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_definition}"
+            cursor.execute(query)
+            conn.commit()
+            logger.info(f"Successfully added column '{column_name}' to table '{table_name}'.")
+        except sqlite3.OperationalError as e:
+            # This specific error means the column might already exist (though _column_exists should prevent this)
+            # or some other operational issue.
+            if f"duplicate column name: {column_name}" in str(e).lower():
+                 logger.info(f"Column '{column_name}' effectively already exists in '{table_name}'. No action needed.")
+            else:
+                logger.error(f"Error adding column '{column_name}' to table '{table_name}': {e}")
+    else:
+        logger.info(f"Column '{column_name}' already exists in table '{table_name}'. No action taken.")
+
+def run_migrations():
+    """Runs the database migrations."""
+    logger.info(f"Attempting to migrate database at: {DB_PATH}")
+    if not os.path.exists(DB_PATH):
+        logger.info(f"Database file not found at {DB_PATH}. Nothing to migrate. "
+                    f"The application will create it with the latest schema on its next start.")
+        return
+
+    conn = None
+    try:
+        conn = _get_db_connection(DB_PATH)
+        current_db_version = _get_current_schema_version(conn)
+        logger.info(f"Current reported database schema version: {current_db_version}")
+
+        if current_db_version < MIGRATION_SET_VERSION:
+            logger.info(f"Schema version {current_db_version} is older than target migration set {MIGRATION_SET_VERSION}. Starting migration...")
+
+            # Define columns to add for MIGRATION_SET_VERSION 1
+            # These definitions match what's in the TradingStats._create_tables() method
+            columns_to_add = {
+                "liquidation_price": "REAL DEFAULT NULL",
+                "margin_used": "REAL DEFAULT NULL",
+                "leverage": "REAL DEFAULT NULL",
+                "position_value": "REAL DEFAULT NULL"
+            }
+
+            for col_name, col_def in columns_to_add.items():
+                _add_column_if_not_exists(conn, "trades", col_name, col_def)
+            
+            # After all operations for this version are successful, update the schema version in DB
+            _set_schema_version(conn, MIGRATION_SET_VERSION)
+            logger.info(f"Successfully migrated database to schema version {MIGRATION_SET_VERSION}.")
+        else:
+            logger.info(f"Database schema version {current_db_version} is already at or newer than migration set {MIGRATION_SET_VERSION}. No migration needed for this set.")
+
+    except sqlite3.Error as e:
+        logger.error(f"A database error occurred during migration: {e}")
+    except Exception as e:
+        logger.error(f"An unexpected error occurred during migration: {e}", exc_info=True)
+    finally:
+        if conn:
+            conn.close()
+            logger.info("Database connection closed.")
+
+if __name__ == "__main__":
+    logger.info("Starting database migration script...")
+    run_migrations()
+    logger.info("Database migration script finished.") 

+ 13 - 2
src/trading/trading_stats.py

@@ -14,6 +14,9 @@ import math
 from collections import defaultdict
 import uuid
 
+# 🆕 Import the migration runner
+from src.migrations.migrate_db import run_migrations as run_db_migrations
+
 logger = logging.getLogger(__name__)
 
 def _normalize_token_case(token: str) -> str:
@@ -34,10 +37,18 @@ class TradingStats:
         """Initialize the stats tracker and connect to SQLite DB."""
         self.db_path = db_path
         self._ensure_data_directory()
+        
+        # 🆕 Run database migrations before connecting and creating tables
+        # This ensures the schema is up-to-date when the connection is made
+        # and tables are potentially created for the first time.
+        logger.info("Running database migrations if needed...")
+        run_db_migrations() # Uses DB_PATH defined in migrate_db.py, which should be the same
+        logger.info("Database migration check complete.")
+        
         self.conn = sqlite3.connect(self.db_path, detect_types=sqlite3.PARSE_DECLTYPES | sqlite3.PARSE_COLNAMES)
         self.conn.row_factory = self._dict_factory
-        self._create_tables()
-        self._initialize_metadata()
+        self._create_tables() # CREATE IF NOT EXISTS will still be useful for first-time setup
+        self._initialize_metadata() # Also potentially sets schema_version if DB was just created
 
     def _dict_factory(self, cursor, row):
         """Convert SQLite rows to dictionaries."""