
Update database migration logic to support multiple schema versions and enhance trading engine with token precision retrieval

- Refactored migration logic to define and apply multiple schema versions sequentially, including a new column for unrealized P&L percentage (a minimal startup-invocation sketch follows the changed-file list below).
- Improved migration logging for better traceability and error handling.
- Added a new method in TradingEngine to fetch and cache token precision information, reducing redundant API calls and improving market data handling.
Carles Sentis committed 3 days ago
parent commit ddf87032e6
2 files changed, 112 insertions(+), 24 deletions(-)
  1. src/migrations/migrate_db.py (+51, -23)
  2. src/trading/trading_engine.py (+61, -1)
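
The refactored migration flow is self-contained in migrate_db.py; below is a minimal sketch of how its entry point might be wired into application startup. The call site and logging setup are illustrative assumptions, not part of this commit; only run_migrations() itself comes from the diff below.

# Hypothetical startup hook -- illustration only, not part of this commit.
import logging

from src.migrations.migrate_db import run_migrations

logging.basicConfig(level=logging.INFO)

if __name__ == "__main__":
    # run_migrations() reads the stored schema version and only applies the
    # migration sets that are still missing, so calling it at every startup
    # is safe and idempotent.
    run_migrations()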

+ 51 - 23
src/migrations/migrate_db.py

@@ -11,9 +11,27 @@ logger = logging.getLogger(__name__)
 PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
 DB_PATH = os.path.join(PROJECT_ROOT, "data", "trading_stats.sqlite")
 
-# MIGRATION_SET_VERSION identifies this specific group of schema changes.
-# We are bringing the schema to a state that includes the new P&L/risk columns.
-MIGRATION_SET_VERSION = 1
+# -----------------------------------------------------------------------------
+# MIGRATION DEFINITIONS
+# -----------------------------------------------------------------------------
+# Define the latest schema version the application expects.
+# Migrations will run until the DB schema_version matches this.
+LATEST_SCHEMA_VERSION = 2
+
+# Define columns for each migration version set
+# Version 1: Initial set of risk/P&L columns
+MIGRATION_SET_VERSION_1_COLUMNS = {
+    "liquidation_price": "REAL DEFAULT NULL",
+    "margin_used": "REAL DEFAULT NULL",
+    "leverage": "REAL DEFAULT NULL",
+    "position_value": "REAL DEFAULT NULL"
+}
+
+# Version 2: Added unrealized P&L percentage
+MIGRATION_SET_VERSION_2_COLUMNS = {
+    "unrealized_pnl_percentage": "REAL DEFAULT NULL"
+}
+# -----------------------------------------------------------------------------
 
 def _get_db_connection(db_path):
     """Gets a database connection."""
@@ -67,11 +85,12 @@ def _add_column_if_not_exists(conn, table_name, column_name, column_definition):
                  logger.info(f"Column '{column_name}' effectively already exists in '{table_name}'. No action needed.")
             else:
                 logger.error(f"Error adding column '{column_name}' to table '{table_name}': {e}")
+                raise # Re-raise to signal migration step failure
     else:
         logger.info(f"Column '{column_name}' already exists in table '{table_name}'. No action taken.")
 
 def run_migrations():
-    """Runs the database migrations."""
+    """Runs the database migrations sequentially."""
     logger.info(f"Attempting to migrate database at: {DB_PATH}")
     if not os.path.exists(DB_PATH):
         logger.info(f"Database file not found at {DB_PATH}. Nothing to migrate. "
@@ -84,27 +103,36 @@ def run_migrations():
         current_db_version = _get_current_schema_version(conn)
         logger.info(f"Current reported database schema version: {current_db_version}")
 
-        if current_db_version < MIGRATION_SET_VERSION:
-            logger.info(f"Schema version {current_db_version} is older than target migration set {MIGRATION_SET_VERSION}. Starting migration...")
-
-            # Define columns to add for MIGRATION_SET_VERSION 1
-            # These definitions match what's in the TradingStats._create_tables() method
-            columns_to_add = {
-                "liquidation_price": "REAL DEFAULT NULL",
-                "margin_used": "REAL DEFAULT NULL",
-                "leverage": "REAL DEFAULT NULL",
-                "position_value": "REAL DEFAULT NULL",
-                "unrealized_pnl_percentage": "REAL DEFAULT NULL"
-            }
-
-            for col_name, col_def in columns_to_add.items():
-                _add_column_if_not_exists(conn, "trades", col_name, col_def)
+        if current_db_version < LATEST_SCHEMA_VERSION:
+            logger.info(f"Schema version {current_db_version} is older than target latest schema version {LATEST_SCHEMA_VERSION}. Starting migration process...")
+
+            # --- Migration for Version 1 ---
+            if current_db_version < 1:
+                logger.info("Applying migration set for version 1...")
+                for col_name, col_def in MIGRATION_SET_VERSION_1_COLUMNS.items():
+                    _add_column_if_not_exists(conn, "trades", col_name, col_def)
+                _set_schema_version(conn, 1)
+                current_db_version = 1 # Update current_db_version after successful migration
+                logger.info("Successfully migrated database to schema version 1.")
+            
+            # --- Migration for Version 2 ---
+            if current_db_version < 2:
+                logger.info("Applying migration set for version 2...")
+                for col_name, col_def in MIGRATION_SET_VERSION_2_COLUMNS.items():
+                    _add_column_if_not_exists(conn, "trades", col_name, col_def)
+                _set_schema_version(conn, 2)
+                current_db_version = 2 # Update current_db_version
+                logger.info("Successfully migrated database to schema version 2.")
             
-            # After all operations for this version are successful, update the schema version in DB
-            _set_schema_version(conn, MIGRATION_SET_VERSION)
-            logger.info(f"Successfully migrated database to schema version {MIGRATION_SET_VERSION}.")
+            # Add more migration blocks here for future versions (e.g., if current_db_version < 3:)
+
+            if current_db_version == LATEST_SCHEMA_VERSION:
+                 logger.info(f"All migrations applied. Database is now at schema version {current_db_version}.")
+            else:
+                 logger.warning(f"Migration process completed, but DB schema version is {current_db_version}, expected {LATEST_SCHEMA_VERSION}. Check migration logic.")
+
         else:
-            logger.info(f"Database schema version {current_db_version} is already at or newer than migration set {MIGRATION_SET_VERSION}. No migration needed for this set.")
+            logger.info(f"Database schema version {current_db_version} is already at or newer than latest target schema {LATEST_SCHEMA_VERSION}. No migration needed.")
 
     except sqlite3.Error as e:
         logger.error(f"A database error occurred during migration: {e}")

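The new run_migrations() ends with a placeholder comment for future versions. As a sketch of how that pattern would extend (hypothetical, not part of this commit; the column name is invented and LATEST_SCHEMA_VERSION would be bumped to 3), a version 3 block could look like this:

# Hypothetical version 3 migration set -- illustration only.
MIGRATION_SET_VERSION_3_COLUMNS = {
    "example_new_metric": "REAL DEFAULT NULL"  # placeholder column name
}

# Inside run_migrations(), after the version 2 block:
            # --- Migration for Version 3 ---
            if current_db_version < 3:
                logger.info("Applying migration set for version 3...")
                for col_name, col_def in MIGRATION_SET_VERSION_3_COLUMNS.items():
                    _add_column_if_not_exists(conn, "trades", col_name, col_def)
                _set_schema_version(conn, 3)
                current_db_version = 3
                logger.info("Successfully migrated database to schema version 3.")
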
+ 61 - 1
src/trading/trading_engine.py

@@ -7,7 +7,7 @@ import os
 import json
 import logging
 from typing import Dict, Any, Optional, Tuple, List
-from datetime import datetime
+from datetime import datetime, timezone
 import uuid # For generating unique bot_order_ref_ids
 
 from src.config.config import Config
@@ -110,6 +110,66 @@ class TradingEngine:
         """Get market data for a symbol."""
         return self.client.get_market_data(symbol)
     
+    # 🆕 Method to get token precision info
+    _markets_cache: Optional[List[Dict[str, Any]]] = None
+    _markets_cache_timestamp: Optional[datetime] = None
+
+    def get_token_info(self, base_asset: str) -> Dict[str, Any]:
+        """Fetch (and cache) market data to find precision for a given base_asset."""
+        # Cache markets for 1 hour to avoid frequent API calls
+        if self._markets_cache is None or \
+           (self._markets_cache_timestamp and 
+            (datetime.now(timezone.utc) - self._markets_cache_timestamp).total_seconds() > 3600):
+            try:
+                logger.info("Fetching and caching markets for token info...")
+                markets_data = self.client.get_markets() # This returns a list of market dicts
+                if markets_data:
+                    self._markets_cache = markets_data
+                    self._markets_cache_timestamp = datetime.now(timezone.utc)
+                    logger.info(f"Successfully cached {len(self._markets_cache)} markets.")
+                else:
+                    logger.warning("get_markets() returned no data. Using empty cache.")
+                    self._markets_cache = [] # Set to empty list to avoid re-fetching immediately
+                    self._markets_cache_timestamp = datetime.now(timezone.utc)
+
+            except Exception as e:
+                logger.error(f"Error fetching markets for token info: {e}. Will use defaults.")
+                self._markets_cache = [] # Prevent re-fetching on immediate subsequent calls
+                self._markets_cache_timestamp = datetime.now(timezone.utc)
+        
+        default_precision = {'amount': 6, 'price': 2} # Default if not found
+        target_symbol_prefix = f"{base_asset.upper()}/"
+
+        if self._markets_cache:
+            for market_details in self._markets_cache:
+                symbol = market_details.get('symbol')
+                if symbol and symbol.upper().startswith(target_symbol_prefix):
+                    precision = market_details.get('precision')
+                    if precision and isinstance(precision, dict) and \
+                       'amount' in precision and 'price' in precision:
+                        logger.debug(f"Found precision for {base_asset}: {precision}")
+                        return {
+                            'precision': precision,
+                            'base_precision': precision.get('amount'), # For direct access
+                            'quote_precision': precision.get('price')  # For direct access
+                        }
+                    else:
+                        logger.warning(f"Market {symbol} found for {base_asset}, but precision data is missing or malformed: {precision}")
+                        return { # Return default but log that market was found
+                            'precision': default_precision,
+                            'base_precision': default_precision['amount'],
+                            'quote_precision': default_precision['price']
+                        }
+            logger.warning(f"No market symbol starting with '{target_symbol_prefix}' found in cached markets for {base_asset}.")
+        else:
+            logger.warning("Markets cache is empty, cannot find token info.")
+
+        return { # Fallback to defaults
+            'precision': default_precision,
+            'base_precision': default_precision['amount'],
+            'quote_precision': default_precision['price']
+        }
+
     def find_position(self, token: str) -> Optional[Dict[str, Any]]:
         """Find an open position for a token."""
         symbol = f"{token}/USDC:USDC"
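
Finally, a short usage sketch for the new get_token_info() method. The TradingEngine construction and the rounding step are assumptions for illustration; only the shape of the returned dict (precision, base_precision, quote_precision) comes from the diff above, and the sketch assumes precision values are decimal-place counts, as in the method's defaults.

# Hypothetical caller -- illustration only.
from src.trading.trading_engine import TradingEngine

engine = TradingEngine()  # assumes a no-argument constructor

info = engine.get_token_info("BTC")
amount_decimals = info["base_precision"]   # precision['amount'] or default 6
price_decimals = info["quote_precision"]   # precision['price'] or default 2

# Round a raw order size to the exchange's amount precision before submitting.
raw_amount = 0.123456789
order_amount = round(raw_amount, int(amount_decimals))
print(f"BTC amount rounded to {int(amount_decimals)} decimals: {order_amount}")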