import os
import time
import threading
import logging
from datetime import datetime, timedelta

import numpy as np
import pandas as pd
import alpaca_trade_api as tradeapi
from alpaca_trade_api.rest import TimeFrame
from dotenv import load_dotenv

# Load variables from a local .env file into the process environment so the
# os.getenv() lookups in get_api() can see them.
load_dotenv()

# Configure the root logger once at import time: INFO and above, one line per
# record formatted as "YYYY-MM-DD HH:MM:SS [LEVEL] message".
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(levelname)s] %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S'
)

# ---------------------------------------------------------------------------
# Shared state (read/written by both the bot thread and Flask routes)
# ---------------------------------------------------------------------------
bot_state = dict(
    # True while the scan loop should keep running; cleared by stop_bot().
    running=False,
    # The threading.Thread created by start_bot(), or None before first start.
    thread=None,
    # Log entry dicts, newest first (see log_event()).
    trade_log=[],
    # Symbols evaluated for buy signals on every scan.
    watchlist=[
        'AAPL', 'MSFT', 'GOOGL', 'AMZN', 'TSLA',
        'NVDA', 'META', 'SPY', 'QQQ', 'AMD',
    ],
    # ISO-format string of the last scan time, or None if no scan ran yet.
    last_scan=None,
    # Number of scans that reached the trading stage.
    scan_count=0,
)

# Scheduling
SCAN_INTERVAL = 15 * 60     # seconds between scans (15 minutes)

# Indicator look-back windows, in daily bars
RSI_PERIOD = 14
SMA_PERIOD = 20

# Signal thresholds: buy candidates need RSI below RSI_BUY, positions are
# sold once RSI rises above RSI_SELL.
RSI_BUY = 35
RSI_SELL = 65

# Risk limits
STOP_LOSS_PCT = -0.05       # sell once a position is down 5% from entry
MAX_POSITION_PCT = 0.20     # size each new position at up to 20% of equity


# ---------------------------------------------------------------------------
# Alpaca API helper
# ---------------------------------------------------------------------------
def get_api() -> tradeapi.REST:
    """Build an Alpaca REST client from environment variables.

    Reads APCA_API_KEY_ID, APCA_API_SECRET_KEY and APCA_API_BASE_URL. When the
    base URL is unset, the paper-trading endpoint is used.
    """
    credentials = {
        'key_id': os.getenv('APCA_API_KEY_ID'),
        'secret_key': os.getenv('APCA_API_SECRET_KEY'),
    }
    endpoint = os.getenv('APCA_API_BASE_URL', 'https://paper-api.alpaca.markets')
    return tradeapi.REST(base_url=endpoint, **credentials)


# ---------------------------------------------------------------------------
# Indicator calculations
# ---------------------------------------------------------------------------
def calculate_rsi(closes: pd.Series, period: int = RSI_PERIOD) -> pd.Series:
    """Compute the RSI of ``closes`` using simple rolling averages.

    Args:
        closes: Closing prices ordered oldest to newest.
        period: Number of price changes averaged in each window.

    Returns:
        A Series aligned with ``closes``. The first ``period`` rows are NaN
        (no full window yet). A window containing gains but no losses yields
        exactly 100; a window with no price movement at all stays NaN because
        the ratio is undefined there.
    """
    delta = closes.diff()
    gain = delta.clip(lower=0)
    loss = -delta.clip(upper=0)
    avg_gain = gain.rolling(window=period, min_periods=period).mean()
    avg_loss = loss.rolling(window=period, min_periods=period).mean()

    # Keep the division itself free of zero denominators; the zero-loss case
    # is resolved explicitly below.
    rs = avg_gain / avg_loss.replace(0, np.nan)
    rsi = 100 - (100 / (1 + rs))

    # With no losses in the window RS is unbounded, so RSI saturates at 100.
    # Leaving this as NaN would hide the strongest overbought readings from
    # callers that treat NaN as "no signal".
    only_gains = (avg_loss == 0) & (avg_gain > 0)
    return rsi.mask(only_gains, 100.0)


def calculate_sma(closes: pd.Series, period: int = SMA_PERIOD) -> pd.Series:
    """Return the trailing simple moving average of ``closes``.

    Rows that do not yet have ``period`` observations behind them are NaN.
    """
    trailing_window = closes.rolling(window=period, min_periods=period)
    return trailing_window.mean()


# ---------------------------------------------------------------------------
# Fetch bars and compute latest RSI / SMA / price
# ---------------------------------------------------------------------------
def get_signals(symbol: str):
    """
    Fetch recent daily bars for ``symbol`` and return its latest indicators.

    Returns (rsi, sma, current_price), where current_price is the close of the
    most recent bar. Returns (None, None, None) when fewer than
    SMA_PERIOD + RSI_PERIOD bars come back, when either indicator is NaN, or
    when fetching/processing the bars raises (the error is logged).
    """
    no_signal = (None, None, None)

    client = get_api()
    now = datetime.utcnow()
    window_start = now - timedelta(days=120)   # ~120 calendar days → ~85 trading days

    try:
        response = client.get_bars(
            symbol,
            TimeFrame.Day,
            start=window_start.strftime('%Y-%m-%d'),
            end=now.strftime('%Y-%m-%d'),
            limit=100,
        )
        bars = response.df

        required_bars = SMA_PERIOD + RSI_PERIOD
        if bars.empty or len(bars) < required_bars:
            log_event(f'Not enough bars for {symbol} (got {len(bars)})', level='warning')
            return no_signal

        closes = bars['close']
        latest_rsi = calculate_rsi(closes).iloc[-1]
        latest_sma = calculate_sma(closes).iloc[-1]
        latest_close = closes.iloc[-1]

        # Only hand back a signal when both indicators are actual numbers.
        if not (pd.isna(latest_rsi) or pd.isna(latest_sma)):
            return latest_rsi, latest_sma, latest_close
        return no_signal

    except Exception as e:
        log_event(f'Failed to fetch bars for {symbol}: {e}', level='error')
        return no_signal


# ---------------------------------------------------------------------------
# Main scan loop
# ---------------------------------------------------------------------------
def scan_and_trade():
    """Thread target: run scans back to back until bot_state['running'] clears.

    Each scan is followed by a SCAN_INTERVAL-second pause. Any exception that
    escapes a scan is logged and the loop carries on with the next cycle.
    """
    log_event('Bot started — scanning every 15 minutes')

    while bot_state['running']:
        try:
            _run_scan()
        except Exception as e:
            log_event(f'Unhandled scan error: {e}', level='error')

        # Count the pause down one second at a time so that clearing the
        # running flag interrupts the wait promptly.
        seconds_left = SCAN_INTERVAL
        while seconds_left > 0 and bot_state['running']:
            time.sleep(1)
            seconds_left -= 1

    log_event('Bot loop exited')


def _run_scan():
    """Run one scan: verify the market is open, then evaluate sells and buys.

    Fetches the market clock, account and open positions. If any of those
    calls fails the error is logged and the scan is abandoned. When the market
    is closed only ``last_scan`` is updated. Otherwise ``last_scan`` and
    ``scan_count`` are updated, every open position is checked for an exit,
    and every watchlist symbol not already held is checked for an entry.
    """
    api = get_api()

    # Market hours check
    try:
        clock = api.get_clock()
    except Exception as e:
        log_event(f'Could not fetch market clock: {e}', level='error')
        return

    if not clock.is_open:
        log_event('Market closed — skipping scan')
        bot_state['last_scan'] = datetime.now().isoformat()
        return

    # Account info
    try:
        account = api.get_account()
    except Exception as e:
        log_event(f'Could not fetch account: {e}', level='error')
        return

    equity = float(account.equity)
    max_position_val = equity * MAX_POSITION_PCT

    # Current positions keyed by symbol
    try:
        positions = {p.symbol: p for p in api.list_positions()}
    except Exception as e:
        log_event(f'Could not fetch positions: {e}', level='error')
        return

    bot_state['last_scan'] = datetime.now().isoformat()
    bot_state['scan_count'] += 1
    log_event(
        f'Scan #{bot_state["scan_count"]} — equity ${equity:,.2f} | '
        f'{len(positions)} open position(s) | watching {len(bot_state["watchlist"])} symbol(s)'
    )

    _evaluate_sell_signals(api, positions)

    # The watchlist is shared with other threads and the buy pass makes a slow
    # network call per symbol. Iterate over a snapshot so a concurrent
    # add/remove cannot make the loop skip or repeat symbols.
    watchlist = list(bot_state['watchlist'])
    _evaluate_buy_signals(api, watchlist, positions, max_position_val)


def _evaluate_sell_signals(api, positions):
    """Submit a market sell for every held position that trips an exit rule.

    Exit rules, in priority order: RSI above RSI_SELL, then loss at or beyond
    STOP_LOSS_PCT. Errors for one symbol are logged and do not stop the rest.
    """
    for symbol, pos in positions.items():
        try:
            rsi, _sma, current_price = get_signals(symbol)

            if rsi is None:
                # Indicators are unavailable (too few bars, NaN, fetch error).
                # The stop-loss must not depend on them, so fall back to the
                # price the broker reported with the position itself.
                broker_price = getattr(pos, 'current_price', None)
                if broker_price is None:
                    continue
                current_price = float(broker_price)

            entry_price = float(pos.avg_entry_price)
            pnl_pct = (current_price - entry_price) / entry_price

            if rsi is not None and rsi > RSI_SELL:
                _submit_sell(api, symbol, pos.qty, f'RSI={rsi:.1f} > {RSI_SELL}', pnl_pct)
            elif pnl_pct <= STOP_LOSS_PCT:
                _submit_sell(api, symbol, pos.qty, 'stop-loss hit', pnl_pct)

        except Exception as e:
            log_event(f'Error in sell check for {symbol}: {e}', level='error')


def _evaluate_buy_signals(api, watchlist, positions, max_position_val):
    """Submit a market buy for every unheld watchlist symbol with an entry signal.

    Entry rule: RSI below RSI_BUY while the latest close is above the SMA.
    Quantity is the whole number of shares that fits in ``max_position_val``.
    Errors for one symbol are logged and do not stop the rest.
    """
    # NOTE(review): sizing is per position; total exposure across several buys
    # in one scan is not capped against cash or buying power — confirm intended.
    for symbol in watchlist:
        if symbol in positions:
            continue  # already holding

        try:
            rsi, sma, current_price = get_signals(symbol)
            if rsi is None:
                continue

            if rsi < RSI_BUY and current_price > sma:
                qty = int(max_position_val / current_price)
                if qty > 0:
                    _submit_buy(api, symbol, qty, rsi, sma, current_price)

        except Exception as e:
            log_event(f'Error in buy check for {symbol}: {e}', level='error')


def _submit_buy(api, symbol, qty, rsi, sma, price):
    """Place a day market buy order and record it in the trade log.

    Nothing is returned or raised: any exception raised while submitting or
    logging the order is caught and reported as an order failure.
    """
    order = {
        'symbol': symbol,
        'qty': qty,
        'side': 'buy',
        'type': 'market',
        'time_in_force': 'day',
    }
    try:
        api.submit_order(**order)
        summary = (
            f'BUY  {symbol:6s}  qty={qty}  price=${price:.2f}  '
            f'RSI={rsi:.1f}  SMA={sma:.2f}'
        )
        log_event(summary, level='trade')
    except Exception as e:
        log_event(f'Order failed (BUY {symbol}): {e}', level='error')


def _submit_sell(api, symbol, qty, reason, pnl_pct):
    """Place a day market sell order and record it in the trade log.

    ``reason`` and ``pnl_pct`` are only used for the log line. Nothing is
    returned or raised: any exception raised while submitting or logging the
    order is caught and reported as an order failure.
    """
    order = {
        'symbol': symbol,
        'qty': qty,
        'side': 'sell',
        'type': 'market',
        'time_in_force': 'day',
    }
    try:
        api.submit_order(**order)
        summary = f'SELL {symbol:6s}  qty={qty}  reason={reason}  PnL={pnl_pct:+.2%}'
        log_event(summary, level='trade')
    except Exception as e:
        log_event(f'Order failed (SELL {symbol}): {e}', level='error')


# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
def log_event(message: str, level: str = 'info'):
    """Record ``message`` in the in-memory trade log and the logging module.

    Args:
        message: Text of the event.
        level: 'error' and 'warning' map to the matching logging level; any
            other value (including 'info' and 'trade') is logged at INFO. The
            value is stored unchanged in the log entry.

    The entry is inserted at the front of ``bot_state['trade_log']``, which is
    capped at the 500 most recent entries.
    """
    entry = {
        'time':    datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        'message': message,
        'level':   level,
    }

    # Mutate the existing list instead of rebinding the key to a sliced copy.
    # Rebinding could drop an entry inserted by another thread between the
    # slice and the assignment, and left other holders with a stale list.
    history = bot_state['trade_log']
    history.insert(0, entry)
    del history[500:]  # keep last 500

    if level == 'error':
        logging.error(message)
    elif level == 'warning':
        logging.warning(message)
    else:
        logging.info(message)


# ---------------------------------------------------------------------------
# Start / stop
# ---------------------------------------------------------------------------
def start_bot() -> bool:
    """Start the background scan thread.

    Returns:
        True if a new thread was started. False if the bot is already running,
        or if a previously stopped loop is still winding down after a short
        wait (call again once it has exited).
    """
    if bot_state['running']:
        return False

    # stop_bot() only clears the flag; the old loop may not have noticed yet.
    # Setting the flag again now would revive that thread next to the new one
    # and leave two loops trading at once, so wait for it to exit first.
    previous = bot_state.get('thread')
    if previous is not None and previous.is_alive():
        previous.join(timeout=2.0)  # the loop polls the flag once per second
        if previous.is_alive():
            return False  # still mid-scan; refuse rather than double up

    bot_state['running'] = True
    worker = threading.Thread(target=scan_and_trade, daemon=True, name='trade-bot')
    bot_state['thread'] = worker
    worker.start()
    return True


def stop_bot() -> bool:
    """Ask the scan loop to stop by clearing the running flag.

    Returns True if the flag was cleared, False if the bot was not running.
    Does not wait for the loop thread to exit.
    """
    if bot_state['running']:
        bot_state['running'] = False
        return True
    return False
