import logging
from typing import Dict, Any, List, Optional
from datetime import datetime, timezone
from telegram import Update
from telegram.ext import ContextTypes
from .base import InfoCommandsBase

logger = logging.getLogger(__name__)

class PositionsCommands(InfoCommandsBase):
    """Handles all position-related commands."""

    async def positions_command(self, update: Update, context: ContextTypes.DEFAULT_TYPE) -> None:
        """Handle the /positions command.

        Replies with one section per open position returned by
        ``stats.get_open_positions()``, followed by a portfolio summary
        (total value, margin, leverage, unrealized P&L).  Mark price and ROE
        are taken from the live exchange positions when a matching entry
        exists and fall back to the stored values otherwise.

        Args:
            update: Incoming Telegram update; used for authorization and as
                the reply target.
            context: Handler context (unused here).

        A position that cannot be rendered is logged and skipped; it then
        contributes neither text nor totals, so the summary always matches
        the positions actually listed.
        """
        try:
            if not self._is_authorized(update):
                await self._reply(update, "❌ Unauthorized access.")
                return

            stats = self.trading_engine.get_stats()
            if not stats:
                await self._reply(update, "❌ Trading stats not available.")
                return

            # Get open positions from DB
            open_positions = stats.get_open_positions()
            if not open_positions:
                await self._reply(update, "📭 No open positions\n\n💡 Use /long or /short to open a position")
                return

            # Live exchange positions, indexed by their 'coin' field, for
            # mark price and ROE lookups.
            exchange_positions = self.trading_engine.get_positions() or []
            exchange_data_by_symbol = {
                ex_pos['coin']: ex_pos
                for ex_pos in exchange_positions
                if ex_pos.get('coin')
            }

            total_position_value = 0.0
            total_unrealized = 0.0
            total_margin_used = 0.0

            sections: List[str] = ["📊 <b>Open Positions</b>\n\n"]

            for position_trade in open_positions:
                try:
                    symbol = position_trade['symbol']
                    base_asset = symbol.split('/', 1)[0]
                    position_side = position_trade.get('position_side') or 'unknown'
                    trade_type = position_trade.get('trade_type', 'manual')
                    lifecycle_id = position_trade.get('trade_lifecycle_id')

                    # Must be resolved before the mark-price / ROE lookups below.
                    exchange_data = exchange_data_by_symbol.get(base_asset)

                    entry_price = self._position_float(
                        position_trade.get('entry_price'), 'entry_price', symbol)
                    current_amount = self._position_float(
                        position_trade.get('current_position_size'), 'current_position_size', symbol)
                    abs_current_amount = abs(current_amount)
                    unrealized_pnl = self._position_float(
                        position_trade.get('unrealized_pnl'), 'unrealized_pnl', symbol)
                    margin_used = self._position_float(
                        position_trade.get('margin_used'), 'margin_used', symbol)
                    liquidation_price = self._position_float(
                        position_trade.get('liquidation_price'), 'liquidation_price', symbol)

                    duration_str = self._position_duration(
                        position_trade.get('position_opened_at'), symbol)

                    # Mark price: live exchange value, then stored value, then entry price.
                    mark_price = None
                    if exchange_data:
                        mark_price = self._position_float(
                            exchange_data.get('markPrice'), 'exchange mark_price', symbol, default=None)
                    if mark_price is None:
                        mark_price = self._position_float(
                            position_trade.get('mark_price'), 'database mark_price', symbol, default=None)
                    if mark_price is None:
                        mark_price = entry_price

                    # ROE: the exchange reports a fraction (0.118 -> 11.8%).  A None
                    # sentinel keeps a genuine 0% from the exchange from being
                    # replaced by the stored value.
                    roe_percentage = None
                    if exchange_data:
                        roe_fraction = self._position_float(
                            exchange_data.get('returnOnEquity'), 'exchange ROE', symbol, default=None)
                        if roe_fraction is not None:
                            roe_percentage = roe_fraction * 100
                    if roe_percentage is None:
                        roe_percentage = self._position_float(
                            position_trade.get('roe_percentage'), 'database roe_percentage', symbol)

                    # Position value: stored value, or size * mark when missing/non-positive.
                    individual_position_value = self._position_float(
                        position_trade.get('position_value'), 'position_value', symbol)
                    if individual_position_value <= 0:
                        individual_position_value = abs_current_amount * mark_price

                    # --- Position Header Formatting (Emoji, Direction, Leverage) ---
                    pos_emoji = "🟢" if position_side == 'long' else "🔴"
                    direction_text = position_side.upper()
                    leverage_str = self._position_leverage_label(position_trade.get('leverage'), symbol)
                    if leverage_str:
                        direction_text = f"{direction_text} {leverage_str}"

                    type_indicator = ""
                    if lifecycle_id:
                        type_indicator = " 🤖"
                    elif trade_type == 'external':
                        type_indicator = " 🔄"

                    # --- Format Output String ---
                    formatter = self._get_formatter()
                    entry_price_str = await formatter.format_price_with_symbol(entry_price, base_asset)
                    size_str = await formatter.format_amount(abs_current_amount, base_asset)

                    lines = [
                        f"{pos_emoji} <b>{base_asset} ({direction_text}){type_indicator}</b>\n",
                        f"   📏 Size: {size_str} {base_asset}\n",
                        f"   💰 Entry: {entry_price_str}\n",
                        f"   ⏳ Duration: {duration_str}\n",
                        f"   🏦 Value: ${individual_position_value:,.2f}\n",
                    ]

                    if margin_used > 0:
                        lines.append(f"   💳 Margin: ${margin_used:,.2f}\n")

                    # Only show the mark price when it differs from the entry price.
                    if mark_price > 0 and abs(mark_price - entry_price) > 1e-9:
                        mark_price_str = await formatter.format_price_with_symbol(mark_price, base_asset)
                        lines.append(f"   📈 Mark: {mark_price_str}\n")

                    pnl_emoji = "🟢" if unrealized_pnl >= 0 else "🔴"
                    lines.append(f"   {pnl_emoji} P&L: ${unrealized_pnl:,.2f}\n")

                    roe_emoji = "🟢" if roe_percentage >= 0 else "🔴"
                    lines.append(f"   {roe_emoji} ROE: {roe_percentage:+.2f}%\n")

                    if liquidation_price > 0:
                        liq_price_str = await formatter.format_price_with_symbol(liquidation_price, base_asset)
                        lines.append(f"   ⚠️ Liquidation: {liq_price_str}\n")

                    # Show stop loss if linked in database
                    sl_price = position_trade.get('stop_loss_price')
                    if sl_price:
                        sl_price_str = await formatter.format_price_with_symbol(sl_price, base_asset)
                        lines.append(f"   🛑 Stop Loss: {sl_price_str}\n")

                    # Show take profit if linked in database
                    tp_price = position_trade.get('take_profit_price')
                    if tp_price:
                        tp_price_str = await formatter.format_price_with_symbol(tp_price, base_asset)
                        lines.append(f"   🎯 Take Profit: {tp_price_str}\n")

                    # The lifecycle id is optional; never let its absence drop the position.
                    if lifecycle_id:
                        lines.append(f"   🆔 Lifecycle ID: {str(lifecycle_id)[:8]}\n")
                    lines.append("\n")

                except Exception as e:
                    logger.error(
                        f"Error processing position {position_trade.get('symbol', 'unknown')}: {e}",
                        exc_info=True,
                    )
                    continue
                else:
                    # Commit text and totals together so a failed position
                    # leaves neither a partial section nor a skewed summary.
                    sections.extend(lines)
                    total_position_value += individual_position_value
                    total_unrealized += unrealized_pnl
                    if margin_used > 0:
                        total_margin_used += margin_used

            # Add portfolio summary
            portfolio_emoji = "🟢" if total_unrealized >= 0 else "🔴"
            sections.append("💼 <b>Total Portfolio:</b>\n")
            sections.append(f"   🏦 Total Positions Value: ${total_position_value:,.2f}\n")
            if total_margin_used > 0:
                sections.append(f"   💳 Total Margin Used: ${total_margin_used:,.2f}\n")
                leverage_ratio = total_position_value / total_margin_used
                sections.append(f"   ⚖️ Portfolio Leverage: {leverage_ratio:.2f}x\n")
            sections.append(f"   {portfolio_emoji} Total Unrealized P&L: ${total_unrealized:,.2f}\n")
            if total_margin_used > 0:
                margin_pnl_percentage = (total_unrealized / total_margin_used) * 100
                sections.append(f"   📊 Portfolio Return: {margin_pnl_percentage:+.2f}% (on margin)\n")
            sections.append("\n")
            sections.append("🤖 <b>Legend:</b> 🤖 Bot-created • 🔄 External/synced • 🛡️ External SL\n")
            sections.append("💡 Use /sl [token] [price] or /tp [token] [price] to set risk management")

            await self._reply(update, "".join(sections).strip())

        except Exception as e:
            logger.error(f"Error in positions command: {e}", exc_info=True)
            await self._reply(update, "❌ Error retrieving position information.")

    @staticmethod
    def _position_float(value: Any, label: str, symbol: str,
                        default: Optional[float] = 0.0) -> Optional[float]:
        """Convert ``value`` to float, returning ``default`` when it is None or unparsable.

        An unparsable value is logged as a warning naming ``label`` and ``symbol``.
        """
        if value is None:
            return default
        try:
            return float(value)
        except (ValueError, TypeError):
            logger.warning(f"Could not convert {label} for {symbol}")
            return default

    @staticmethod
    def _position_duration(opened_at: Any, symbol: str) -> str:
        """Render the time since ``opened_at`` as e.g. ``"2d 3h 15m"``.

        Returns ``"N/A"`` when no timestamp is given and ``"Error"`` when it
        cannot be parsed.  Naive timestamps are treated as UTC.
        """
        if not opened_at:
            return "N/A"
        try:
            if isinstance(opened_at, datetime):
                opened_at_dt = opened_at
            else:
                # fromisoformat() rejects a trailing 'Z' before Python 3.11.
                if isinstance(opened_at, str) and opened_at.endswith('Z'):
                    opened_at = opened_at[:-1] + '+00:00'
                opened_at_dt = datetime.fromisoformat(opened_at)
        except (ValueError, TypeError):
            logger.warning(f"Could not parse position_opened_at: {opened_at} for {symbol}")
            return "Error"

        if opened_at_dt.tzinfo is None:
            opened_at_dt = opened_at_dt.replace(tzinfo=timezone.utc)

        # Clamp at zero so a timestamp slightly in the future (clock skew)
        # does not produce negative day counts.
        elapsed_seconds = max(int((datetime.now(timezone.utc) - opened_at_dt).total_seconds()), 0)
        days, remainder = divmod(elapsed_seconds, 86400)
        hours, remainder = divmod(remainder, 3600)
        minutes = remainder // 60

        parts = []
        if days > 0:
            parts.append(f"{days}d")
        if hours > 0:
            parts.append(f"{hours}h")
        if minutes > 0 or (days == 0 and hours == 0):
            parts.append(f"{minutes}m")
        return " ".join(parts) if parts else "0m"

    @staticmethod
    def _position_leverage_label(leverage: Any, symbol: str) -> str:
        """Return a label such as ``"x10"`` or ``"x2.5"``, or ``""`` if leverage is absent or invalid."""
        if leverage is None:
            return ""
        try:
            leverage_val = float(leverage)
        except (ValueError, TypeError):
            logger.warning(f"Could not parse leverage value: {leverage} for {symbol}")
            return ""
        text = f"{leverage_val:.1f}"
        if '.' in text:
            # Drop a redundant ".0" (10.0 -> "10") but keep real fractions (2.5).
            text = text.rstrip('0').rstrip('.')
        return f"x{text}"

    def _get_external_stop_losses(self, symbol: str, position_side: str, entry_price: float, 
                                current_amount: float, exchange_orders: List[Dict[str, Any]]) -> List[float]:
        """Get external stop losses for a position."""
        external_sls = []
        for order in exchange_orders:
            try:
                order_symbol = order.get('symbol')
                order_side = order.get('side', '').lower()
                order_type = order.get('type', '').lower()
                order_price = float(order.get('price', 0))
                trigger_price = order.get('info', {}).get('triggerPrice')
                is_reduce_only = order.get('reduceOnly', False) or order.get('info', {}).get('reduceOnly', False)
                order_amount = float(order.get('amount', 0))
                
                if (order_symbol == symbol and is_reduce_only and 
                    abs(order_amount - current_amount) < 0.01 * current_amount):
                    
                    sl_trigger_price = 0
                    if trigger_price:
                        try:
                            sl_trigger_price = float(trigger_price)
                        except (ValueError, TypeError):
                            pass
                    if not sl_trigger_price and order_price > 0:
                        sl_trigger_price = order_price
                    
                    is_valid_sl = False
                    if position_side == 'long' and order_side == 'sell':
                        if sl_trigger_price > 0 and sl_trigger_price < entry_price:
                            is_valid_sl = True
                    elif position_side == 'short' and order_side == 'buy':
                        if sl_trigger_price > 0 and sl_trigger_price > entry_price:
                            is_valid_sl = True
                    
                    if is_valid_sl:
                        external_sls.append(sl_trigger_price)
                        
            except Exception as e:
                logger.warning(f"Error processing order for external SL: {e}")
                continue
                
        return external_sls