import os
import json
import time
import threading
import logging
from datetime import datetime, timedelta
from zoneinfo import ZoneInfo

EST = ZoneInfo('America/New_York')

import numpy as np
import pandas as pd
import alpaca_trade_api as tradeapi
from alpaca_trade_api.rest import TimeFrame
from dotenv import load_dotenv
from openai import OpenAI

# Load variables from a .env file (if present) into os.environ; the code
# below reads APCA_API_* and OPENAI_API_KEY via os.getenv().
load_dotenv()

# Console logging; log_event() echoes every trade-log entry through here.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# JSON persistence files, stored in the same directory as this module.
LOG_FILE       = os.path.join(os.path.dirname(__file__), 'trade_log.json')
WATCHLIST_FILE = os.path.join(os.path.dirname(__file__), 'watchlist.json')
LOG_MAX        = 500   # max trade-log entries kept (newest first), in memory and on disk

# Symbols considered for buys. Used as-is when no watchlist.json exists;
# otherwise any of these missing from the saved watchlist are appended to it
# when the watchlist is loaded.
DEFAULT_WATCHLIST = [
    # Large cap tech
    'AAPL', 'MSFT', 'GOOGL', 'AMZN', 'NVDA', 'META', 'TSLA', 'AMD',
    'INTC', 'MU', 'QCOM', 'AMAT', 'SMCI', 'ROKU', 'PINS', 'RDDT',
    # ETFs
    'SPY', 'QQQ', 'ARKK', 'SOXL',
    # High volatility / dip candidates
    'COIN', 'PLTR', 'RIVN', 'SOFI', 'MARA', 'UBER', 'BA', 'F',
    'SNAP', 'HOOD', 'RBLX', 'LCID', 'NIO', 'DKNG', 'PENN', 'PLUG',
    # Crypto-adjacent
    'RIOT', 'BTBT', 'HUT', 'CIFR', 'CLSK',
    # Chinese tech (very volatile)
    'BABA', 'JD', 'BIDU', 'PDD', 'XPEV', 'LI',
    # Clean energy
    'ENPH', 'FSLR', 'BE', 'CHPT', 'SPWR', 'ARRY',
    # Fintech
    'AFRM', 'UPST', 'OPEN', 'LC',
    # Biotech (high volatility)
    'MRNA', 'NVAX', 'SAVA', 'BBIO',
    # AI plays
    'SOUN', 'IONQ', 'BBAI', 'AI', 'BIGC',
    # Other volatile
    'GME', 'AMC', 'SKLZ', 'CLOV',
    'SPCE', 'MAXN', 'HYLN',
]


def _write_json_atomic(path: str, data) -> None:
    """
    Serialize *data* as JSON to *path* without ever exposing a partial file.

    The JSON is written to a temporary sibling file, which then replaces
    *path* in a single ``os.replace``. A crash or a second writer mid-write
    therefore cannot leave a truncated file behind. The temp name embeds the
    process and thread id so two writers never share one. Any error is
    re-raised after the temp file has been removed.
    """
    tmp_path = f'{path}.{os.getpid()}.{threading.get_ident()}.tmp'
    try:
        with open(tmp_path, 'w') as f:
            json.dump(data, f)
        os.replace(tmp_path, path)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # temp file was never created (or is already gone)
        raise


def _load_log_from_disk() -> list:
    """
    Return the persisted trade log, or ``[]`` if it is missing or unusable.

    Unreadable files, invalid JSON and JSON that is not a list are ignored
    with a warning rather than raised. The log is best-effort, and
    ``log_event`` needs a list it can ``insert`` into.
    """
    if not os.path.exists(LOG_FILE):
        return []
    try:
        with open(LOG_FILE, 'r') as f:
            data = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
        logging.warning(f'Could not load log from disk: {e}')
        return []
    if not isinstance(data, list):
        logging.warning(f'Ignoring {LOG_FILE}: expected a JSON list, got {type(data).__name__}')
        return []
    return data


def _save_log_to_disk(log: list):
    """
    Persist the first ``LOG_MAX`` entries of *log* to ``LOG_FILE`` atomically.

    Failures are logged rather than raised, so persisting the log can never
    interrupt trading.
    """
    try:
        _write_json_atomic(LOG_FILE, log[:LOG_MAX])
    except Exception as e:
        logging.warning(f'Could not save log to disk: {e}')


def _load_watchlist_from_disk() -> list:
    """
    Return the persisted watchlist, topped up with any missing defaults.

    Symbols from ``DEFAULT_WATCHLIST`` that are absent from the saved list are
    appended, preserving the saved order, and the merged list is written back.
    Falls back to a copy of ``DEFAULT_WATCHLIST`` (with a warning for bad
    files) when the file is missing, unreadable, or not a JSON list.

    NOTE(review): every default is re-added on load, so a default symbol the
    user removed from the watchlist reappears after a restart.
    """
    if not os.path.exists(WATCHLIST_FILE):
        return DEFAULT_WATCHLIST[:]
    try:
        with open(WATCHLIST_FILE, 'r') as f:
            existing = json.load(f)
    except (OSError, ValueError) as e:
        logging.warning(f'Could not load watchlist from disk: {e}')
        return DEFAULT_WATCHLIST[:]
    if not isinstance(existing, list):
        logging.warning(f'Ignoring {WATCHLIST_FILE}: expected a JSON list, got {type(existing).__name__}')
        return DEFAULT_WATCHLIST[:]

    # Merge any new default symbols not already in the list
    merged = existing + [s for s in DEFAULT_WATCHLIST if s not in existing]
    if len(merged) != len(existing):
        _save_watchlist_to_disk(merged)
    return merged


def _save_watchlist_to_disk(watchlist: list):
    """Persist *watchlist* to ``WATCHLIST_FILE`` atomically; failures are logged."""
    try:
        _write_json_atomic(WATCHLIST_FILE, watchlist)
    except Exception as e:
        logging.warning(f'Could not save watchlist to disk: {e}')

# ---------------------------------------------------------------------------
# Shared state
# ---------------------------------------------------------------------------
bot_state = {
    'running': False,      # loop flag: set True by start_bot(), cleared by stop_bot()
    'thread': None,        # the threading.Thread created by start_bot()
    'trade_log': _load_log_from_disk(),        # log_event() entries, newest first (loaded at import)
    'watchlist': _load_watchlist_from_disk(),  # symbols checked for buys each scan (loaded at import)
    'last_scan': None,     # ISO-8601 America/New_York timestamp of the latest scan attempt
    'scan_count': 0,       # scans that got as far as fetching open positions
}

SCAN_INTERVAL       = 900    # 15 minutes
RSI_PERIOD          = 14     # bars used by calculate_rsi()
SMA_PERIOD          = 20     # bars used by calculate_sma()
MACD_FAST           = 12     # EMA spans used by calculate_macd()
MACD_SLOW           = 26
MACD_SIGNAL         = 9
RSI_BUY             = 70     # pre-filter only — AI makes final call (buys skipped at RSI >= this)
RSI_SELL            = 75     # pre-filter only — AI makes final call
MIN_VOLUME_MA       = 20     # prior bars averaged by volume_ok()
MIN_VOLUME_MULT     = 0.25   # very loose — just block completely dead stocks
TRAIL_PCT           = 0.08   # 8% trailing stop — gives volatile stocks room to breathe
MAX_POSITION_PCT    = 0.20   # max fraction of account equity spent on one new position


# ---------------------------------------------------------------------------
# AI decision layer (ai_decision) and Alpaca API helper (get_api)
# ---------------------------------------------------------------------------
def ai_decision(symbol: str, action: str, sig: dict, extra: str = '',
                timeout: float = 30.0) -> tuple:
    """
    Ask GPT-4o-mini whether to BUY or SELL a position.

    Args:
        symbol:  ticker being evaluated.
        action:  'buy' or 'sell'. 'buy' selects the buy wording; anything else
                 is phrased as a sell.
        sig:     signal dict from get_signals(); reads 'price', 'rsi', 'sma',
                 'macd_hist', 'volume_ok' and 'hourly_up'.
        extra:   extra prompt lines appended after the technical data.
        timeout: seconds to wait for the OpenAI request. The client's default
                 is very long (10 minutes), which would stall the scan loop
                 on a hung request.

    Returns:
        (approved, reason, detail)
        approved: True only if the model's answer starts with "APPROVE".
        reason:   the model's one-line answer, or a fallback message.
        detail:   prompt + response transcript for the trade log, or None when
                  the AI call failed.

    NOTE(review): on any exception this FAILS OPEN (returns approved=True).
    During an OpenAI outage, buys therefore rely on the loose RSI/volume
    pre-filter alone.
    """
    try:
        client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'), timeout=timeout)

        action_verb = 'BUY INTO' if action == 'buy' else 'SELL AND EXIT'

        prompt = f"""You are an algorithmic trading AI managing a real-money brokerage account ($500 total).
Should I {action_verb} {symbol} RIGHT NOW?

Technical data:
- Current Price: ${sig['price']:.2f}
- RSI(14): {sig['rsi']:.1f} — oversold <30, neutral 30-60, overbought >70
- SMA(20): {sig['sma']:.2f} — price is {"ABOVE (bullish)" if sig['price'] > sig['sma'] else "BELOW (bearish)"}
- MACD Histogram: {sig['macd_hist']:.4f} ({"bullish momentum" if sig['macd_hist'] > 0 else "bearish momentum"})
- Volume: {"Normal/High" if sig['volume_ok'] else "Below average"}
- Hourly trend: {"Uptrend" if sig['hourly_up'] else "Downtrend"}
{extra}

Answer ONLY with one of these two formats — no other text:
APPROVE: [one sentence explaining why I should {action_verb} now]
REJECT: [one sentence explaining why I should NOT {action_verb} now]"""

        response = client.chat.completions.create(
            model='gpt-4o-mini',
            messages=[{'role': 'user', 'content': prompt}],
            max_tokens=200,
            temperature=0.2,
        )

        # message.content may be None (e.g. a refusal). Treat an empty answer
        # as a rejection instead of letting it raise into the fail-open path.
        content  = response.choices[0].message.content
        text     = (content or '').strip() or 'REJECT: model returned an empty response'
        approved = text.upper().startswith('APPROVE')
        detail   = f"=== AI Analysis for {symbol} ({action.upper()}) ===\n\n{prompt}\n\n=== GPT-4o-mini Response ===\n\n{text}"
        return approved, text, detail

    except Exception as e:
        log_event(f'AI decision failed for {symbol}: {e} — falling back to rules', level='warning')
        return True, 'AI unavailable — rules only', None


def get_api() -> tradeapi.REST:
    """
    Build an Alpaca REST client from the APCA_* environment variables.

    Uses the paper-trading endpoint when APCA_API_BASE_URL is not set. A new
    client object is constructed on every call.
    """
    credentials = {
        'key_id':     os.getenv('APCA_API_KEY_ID'),
        'secret_key': os.getenv('APCA_API_SECRET_KEY'),
        'base_url':   os.getenv('APCA_API_BASE_URL', 'https://paper-api.alpaca.markets'),
    }
    return tradeapi.REST(**credentials)


# ---------------------------------------------------------------------------
# Indicators
# ---------------------------------------------------------------------------
def calculate_rsi(closes: pd.Series, period: int = RSI_PERIOD) -> pd.Series:
    """
    Relative Strength Index of *closes*, one value per bar.

    Average gain and average loss are simple rolling means over *period*
    price changes, not Wilder's exponential smoothing. Values may therefore
    differ from tools that use Wilder's RSI. The first *period* values are
    NaN because a full window is required.

    Edge cases:
      * gains but no losses in the window -> 100 (RS is infinite);
      * a perfectly flat window (no gains, no losses) -> NaN (undefined).
    """
    delta    = closes.diff()
    gain     = delta.clip(lower=0)
    loss     = -delta.clip(upper=0)
    avg_gain = gain.rolling(window=period, min_periods=period).mean()
    avg_loss = loss.rolling(window=period, min_periods=period).mean()
    rs  = avg_gain / avg_loss.replace(0, np.nan)
    rsi = 100 - (100 / (1 + rs))
    # The zero average loss was masked to NaN above. An unbroken run of gains
    # is maximally overbought, so report 100 instead of NaN; otherwise
    # get_signals() discards the symbol (it returns None on a NaN RSI).
    return rsi.mask((avg_loss == 0) & (avg_gain > 0), 100.0)


def calculate_sma(closes: pd.Series, period: int = SMA_PERIOD) -> pd.Series:
    """
    Simple moving average of *closes* over *period* bars.

    Values stay NaN until a full window of *period* observations is available.
    """
    window = closes.rolling(period, min_periods=period)
    return window.mean()


def calculate_macd(closes: pd.Series):
    """
    Compute MACD from *closes* and return only the latest bar's values.

    Returns a ``(macd_line, signal_line, histogram)`` tuple of scalars:
      * the line is EMA(MACD_FAST) minus EMA(MACD_SLOW) of the closes;
      * the signal is EMA(MACD_SIGNAL) of that line;
      * the histogram is line minus signal.
    All EMAs use ``adjust=False``.
    """
    def ema(series, span):
        return series.ewm(span=span, adjust=False).mean()

    line   = ema(closes, MACD_FAST) - ema(closes, MACD_SLOW)
    signal = ema(line, MACD_SIGNAL)
    hist   = line - signal
    return line.iloc[-1], signal.iloc[-1], hist.iloc[-1]


def volume_ok(volumes: pd.Series) -> bool:
    """
    True if latest volume is at least MIN_VOLUME_MULT × recent average.

    The average covers the MIN_VOLUME_MA bars *before* the latest one, so
    MIN_VOLUME_MA + 1 bars are required. With fewer bars the check passes
    ("not enough data — don't block"), rather than silently averaging a
    shorter window.
    """
    if len(volumes) < MIN_VOLUME_MA + 1:
        return True   # not enough data — don't block
    avg = volumes.iloc[-MIN_VOLUME_MA - 1:-1].mean()
    # bool(): the pandas/numpy comparison yields numpy.bool_, not a plain bool
    return bool(volumes.iloc[-1] >= avg * MIN_VOLUME_MULT)


# ---------------------------------------------------------------------------
# Fetch bars — daily + hourly for multi-timeframe
# ---------------------------------------------------------------------------
def get_signals(symbol: str):
    """
    Fetch daily + hourly bars for *symbol* and compute its trading signals.

    Returns a dict with keys:
        rsi, sma, macd, macd_signal, macd_hist, price
            latest values computed from daily closes
        volume_ok
            volume_ok() applied to the daily volumes
        hourly_up
            4-bar hourly SMA above the 8-bar hourly SMA. True when hourly data
            is unavailable or shorter than 8 bars, so it never blocks a trade.

    Returns None instead in any of these cases:
      * there are too few daily bars (logged as a warning);
      * RSI or SMA comes out NaN;
      * the daily fetch raises (logged as an error).
    """
    api   = get_api()
    # Aware UTC "now" (datetime.utcnow() is deprecated since Python 3.12);
    # only its calendar date is sent to the API, exactly as before.
    end   = datetime.now(ZoneInfo('UTC'))
    start = end - timedelta(days=150)
    # NOTE(review): start/end are sent as bare 'YYYY-MM-DD' dates. Confirm how
    # the Alpaca bars endpoint treats a date-only `end`. If it means 00:00 of
    # that day, the in-progress session (and so the latest `price`) would be
    # excluded from both the daily and the hourly request.

    try:
        # --- Daily bars ---
        bars = api.get_bars(
            symbol, TimeFrame.Day,
            start=start.strftime('%Y-%m-%d'),
            end=end.strftime('%Y-%m-%d'),
            limit=120,
            feed='iex',
        ).df

        if bars.empty or len(bars) < SMA_PERIOD + RSI_PERIOD:
            log_event(f'Not enough daily bars for {symbol} ({len(bars)})', level='warning')
            return None

        closes  = bars['close']
        volumes = bars['volume']

        rsi             = calculate_rsi(closes).iloc[-1]
        sma             = calculate_sma(closes).iloc[-1]
        macd, sig, hist = calculate_macd(closes)
        price           = closes.iloc[-1]
        vol_ok          = volume_ok(volumes)

        if pd.isna(rsi) or pd.isna(sma):
            return None

        # --- Hourly bars for trend confirmation (best-effort, never blocks) ---
        h_start = end - timedelta(days=5)
        try:
            h_bars = api.get_bars(
                symbol, TimeFrame.Hour,
                start=h_start.strftime('%Y-%m-%d'),
                end=end.strftime('%Y-%m-%d'),
                limit=120,
                feed='iex',
            ).df
            if len(h_bars) >= 8:
                h_sma_fast = h_bars['close'].rolling(4).mean().iloc[-1]
                h_sma_slow = h_bars['close'].rolling(8).mean().iloc[-1]
                hourly_up  = h_sma_fast > h_sma_slow
            else:
                hourly_up  = True   # not enough data — don't block
        except Exception:
            hourly_up = True

        return {
            'rsi':          rsi,
            'sma':          sma,
            'macd':         macd,
            'macd_signal':  sig,
            'macd_hist':    hist,
            'price':        price,
            'volume_ok':    vol_ok,
            'hourly_up':    hourly_up,
        }

    except Exception as e:
        log_event(f'Failed to fetch bars for {symbol}: {e}', level='error')
        return None


# ---------------------------------------------------------------------------
# Trailing stop tracker
# ---------------------------------------------------------------------------
_trail_highs: dict = {}   # symbol -> highest price seen since entry (kept in memory, not persisted)


def update_trail(symbol: str, current_price: float):
    """
    Raise *symbol*'s recorded high to *current_price* if it is higher.

    The first observation for a symbol seeds its high.
    """
    _trail_highs[symbol] = max(_trail_highs.get(symbol, current_price), current_price)


def trail_stop_hit(symbol: str, current_price: float) -> bool:
    """
    True once *current_price* is at or below the high minus TRAIL_PCT.

    Symbols without a recorded high never trigger.
    """
    try:
        high = _trail_highs[symbol]
    except KeyError:
        return False
    return current_price <= high * (1 - TRAIL_PCT)


def clear_trail(symbol: str):
    """Forget *symbol*'s recorded high (no-op if none is recorded)."""
    if symbol in _trail_highs:
        del _trail_highs[symbol]


# ---------------------------------------------------------------------------
# Main scan loop
# ---------------------------------------------------------------------------
def scan_and_trade():
    """
    Bot thread body: scan, then wait SCAN_INTERVAL seconds, until stopped.

    Errors that escape a scan are logged, so one bad scan never ends the loop.
    The wait is split into 1-second sleeps that re-check
    ``bot_state['running']``. Clearing the flag therefore ends the loop about
    a second after the current scan finishes.
    """
    log_event('Bot started — scanning every 15 minutes')
    while bot_state['running']:
        try:
            _run_scan()
        except Exception as e:
            log_event(f'Unhandled scan error: {e}', level='error')

        waited = 0
        while waited < SCAN_INTERVAL and bot_state['running']:
            time.sleep(1)
            waited += 1

    log_event('Bot loop exited')


def _run_scan():
    """
    Run one full scan: manage open positions, then look for new entries.

    Returns early (after logging) when the market clock, account or position
    list cannot be fetched, and when the market is closed. Otherwise:

    1. ``_scan_sells`` checks every open position for a trailing-stop exit
       or an AI-confirmed technical exit.
    2. ``_scan_buys`` checks every watchlist symbol not already held for an
       AI-confirmed entry. Size is capped by MAX_POSITION_PCT of equity and
       by the buying power still unspent in this scan.

    Both phases check ``bot_state['running']`` before each symbol. A
    ``stop_bot()`` call therefore takes effect mid-scan, instead of after
    every remaining symbol has been fetched, sent to the AI and possibly
    traded.
    """
    api = get_api()

    try:
        clock = api.get_clock()
    except Exception as e:
        log_event(f'Could not fetch market clock: {e}', level='error')
        return

    if not clock.is_open:
        log_event('Market closed — skipping scan')
        bot_state['last_scan'] = datetime.now(EST).isoformat()
        return

    try:
        account = api.get_account()
    except Exception as e:
        log_event(f'Could not fetch account: {e}', level='error')
        return

    equity           = float(account.equity)
    max_position_val = equity * MAX_POSITION_PCT

    try:
        positions = {p.symbol: p for p in api.list_positions()}
    except Exception as e:
        log_event(f'Could not fetch positions: {e}', level='error')
        return

    # Forget trailing-stop highs only once a position is actually gone, not
    # as soon as a sell is submitted. _submit_sell() swallows order failures,
    # so clearing eagerly would silently disarm the stop on a position that is
    # still open: the next scan would re-seed its high at the lower price.
    for stale in [s for s in _trail_highs if s not in positions]:
        clear_trail(stale)

    bot_state['last_scan']  = datetime.now(EST).isoformat()
    bot_state['scan_count'] += 1
    log_event(
        f'Scan #{bot_state["scan_count"]} — equity ${equity:,.2f} | '
        f'{len(positions)} open position(s) | watching {len(bot_state["watchlist"])} symbol(s)'
    )

    _scan_sells(api, positions)
    if not bot_state['running']:
        return

    # --- BUY checks ---
    buying_power = float(account.buying_power)
    if buying_power < 10:
        log_event(f'Insufficient buying power (${buying_power:.2f}) — skipping buy scans')
        bot_state['last_scan'] = datetime.now(EST).isoformat()
        return

    _scan_buys(api, positions, max_position_val, buying_power)


def _scan_sells(api, positions: dict):
    """Exit open positions on a trailing stop or an AI-approved sell signal."""
    for symbol, pos in positions.items():
        if not bot_state['running']:
            log_event('Stop requested — ending scan early')
            return
        try:
            sig = get_signals(symbol)
            if sig is None:
                continue

            price = sig['price']
            update_trail(symbol, price)

            entry  = float(pos.avg_entry_price)
            pnl    = (price - entry) / entry

            # Hard stop — always sell, no AI override. The trail entry is
            # cleared by _run_scan() once the position has actually closed.
            if trail_stop_hit(symbol, price):
                _submit_sell(api, symbol, pos.qty, 'trail-stop hit', pnl)
                log_event(f'HARD STOP SELL {symbol} — trail stop hit  PnL={pnl:+.2%}', level='trade')
                continue

            # Only evaluate selling if there's a real technical reason
            sell_signals = []
            if sig['rsi'] > RSI_SELL:
                sell_signals.append(f'RSI extremely overbought ({sig["rsi"]:.1f})')
            if sig['macd_hist'] < 0 and sig['macd'] < sig['macd_signal'] and pnl > 0:
                sell_signals.append('MACD turned bearish on a winning position')
            if pnl <= -0.03:
                sell_signals.append(f'losing position ({pnl:+.2%})')

            if not sell_signals:
                log_event(f'HOLD {symbol} — no sell signals  PnL={pnl:+.2%}')
                continue

            # AI makes the final call on whether to sell
            approved, reason, detail = ai_decision(
                symbol, 'sell', sig,
                extra=f'- Entry price: ${entry:.2f}\n- Current PnL: {pnl:+.2%}\n- Sell signals triggered: {", ".join(sell_signals)}',
            )
            if approved:
                _submit_sell(api, symbol, pos.qty, ' + '.join(sell_signals), pnl)
                log_event(f'AI APPROVED SELL {symbol} — {reason}', level='trade', detail=detail)
            else:
                log_event(f'AI OVERRIDING HOLD {symbol} — {reason}', level='info', detail=detail)

        except Exception as e:
            log_event(f'Error in sell check for {symbol}: {e}', level='error')


def _scan_buys(api, positions: dict, max_position_val: float, buying_power: float):
    """
    Open AI-approved positions in unheld watchlist symbols within the budget.

    Each buy is capped by ``max_position_val`` and by the buying power not yet
    spent in this scan. Later candidates are thus not sent orders the account
    can no longer fund, and the AI is not consulted for a symbol whose qty
    would be 0.
    """
    for symbol in bot_state['watchlist']:
        if not bot_state['running']:
            log_event('Stop requested — ending scan early')
            return
        if symbol in positions:
            continue
        if buying_power < 10:
            log_event(f'Buying power exhausted (${buying_power:.2f}) — ending buy scans')
            return

        try:
            sig = get_signals(symbol)
            if sig is None:
                continue

            # Basic pre-filter — just block extreme overbought and dead volume
            rsi_ok = sig['rsi'] < RSI_BUY
            vol_ok = sig['volume_ok']

            if not rsi_ok:
                log_event(f'SKIP {symbol} — RSI={sig["rsi"]:.1f} extremely overbought')
                continue
            if not vol_ok:
                log_event(f'SKIP {symbol} — volume too low to trade')
                continue

            # Size before asking the AI: no point paying for a decision we
            # could not act on.
            budget = min(max_position_val, buying_power)
            qty = int(budget / sig['price'])
            if qty <= 0:
                log_event(f'SKIP {symbol} — price ${sig["price"]:.2f} exceeds position budget ${budget:.2f}')
                continue

            # AI makes the real decision
            approved, reason, detail = ai_decision(symbol, 'buy', sig)
            if approved:
                _submit_buy(api, symbol, qty, sig)
                # Reserve the spend even if the order failed. Over-reserving
                # only skips some candidates; under-reserving sends orders
                # the broker will reject.
                buying_power -= qty * sig['price']
                log_event(f'AI APPROVED BUY {symbol} — {reason}', level='trade', detail=detail)
            else:
                log_event(f'AI SKIP {symbol} — {reason}', level='info', detail=detail)

        except Exception as e:
            log_event(f'Error in buy check for {symbol}: {e}', level='error')


def _submit_buy(api, symbol, qty, sig):
    """
    Place a market BUY (day order) for *qty* shares of *symbol* and log it.

    Once submit_order() returns without raising, the trailing-stop high is
    seeded at ``sig['price']``. Any failure is logged at error level and
    swallowed; the function returns None either way.
    """
    try:
        api.submit_order(
            symbol=symbol,
            qty=qty,
            side='buy',
            type='market',
            time_in_force='day',
        )
        entry_price = sig['price']
        _trail_highs[symbol] = entry_price
        summary = (
            f'BUY  {symbol:6s}  qty={qty}  price=${entry_price:.2f}  '
            f'RSI={sig["rsi"]:.1f}  MACD_hist={sig["macd_hist"]:.3f}  '
            f'SMA={sig["sma"]:.2f}'
        )
        log_event(summary, level='trade')
    except Exception as exc:
        log_event(f'Order failed (BUY {symbol}): {exc}', level='error')


def _submit_sell(api, symbol, qty, reason, pnl_pct):
    """
    Place a market SELL (day order) for *qty* shares of *symbol* and log it.

    *reason* and *pnl_pct* only appear in the log line. Failures are logged at
    error level and swallowed, so the function returns None whether or not the
    order was accepted.
    """
    try:
        api.submit_order(
            symbol=symbol,
            qty=qty,
            side='sell',
            type='market',
            time_in_force='day',
        )
        summary = f'SELL {symbol:6s}  qty={qty}  reason={reason}  PnL={pnl_pct:+.2%}'
        log_event(summary, level='trade')
    except Exception as exc:
        log_event(f'Order failed (SELL {symbol}): {exc}', level='error')


# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
def log_event(message: str, level: str = 'info', detail: str = None):
    """
    Record *message* in the trade log, persist the log, and echo it.

    The entry is stamped with America/New_York time (EST), prepended so the
    log stays newest-first, and the log is truncated to LOG_MAX. The whole log
    is then re-saved to disk on every call. *detail* (e.g. an AI transcript)
    is attached only when non-empty. 'error' and 'warning' go to the matching
    stdlib logging call; every other level, including 'trade', goes to INFO.
    """
    record = {
        'time':    datetime.now(EST).strftime('%Y-%m-%d %H:%M:%S'),
        'message': message,
        'level':   level,
    }
    if detail:
        record['detail'] = detail

    history = bot_state['trade_log']
    history.insert(0, record)
    bot_state['trade_log'] = history[:LOG_MAX]
    _save_log_to_disk(bot_state['trade_log'])

    emit = {'error': logging.error, 'warning': logging.warning}.get(level, logging.info)
    emit(message)


# ---------------------------------------------------------------------------
# Start / stop
# ---------------------------------------------------------------------------
def start_bot() -> bool:
    """
    Start the trading thread; return True if a new thread was started.

    Returns False when the bot is already running. It also returns False
    while a previously stopped thread is still alive (e.g. finishing a scan
    or its sleep loop). stop_bot() only clears ``bot_state['running']``, so
    starting again immediately would flip the flag back to True. The old
    thread would then keep looping, and two scan loops would trade
    concurrently.
    """
    if bot_state['running']:
        return False
    previous = bot_state['thread']
    if previous is not None and previous.is_alive():
        return False
    bot_state['running'] = True
    t = threading.Thread(target=scan_and_trade, daemon=True, name='trade-bot')
    bot_state['thread'] = t
    t.start()
    return True


def stop_bot() -> bool:
    """
    Ask the trading thread to stop; return False if it was not running.

    This only clears ``bot_state['running']``. The thread notices at its next
    check and exits on its own; this function does not wait for it.
    """
    if bot_state['running']:
        bot_state['running'] = False
        return True
    return False
