"""
Run batch predictions for all active customers.
Writes results to churn_predictions table, including SHAP explanations.
"""
import sys
import argparse
import joblib
import pandas as pd
import numpy as np
import shap
import json
import psycopg2
from config import DB_CONFIG, DEFAULT_APP_ID

def run_predictions(app_id):
    """Score all active customers with the trained model.

    Loads the newest active model from ``churn_model_registry``, scores every
    customer with a fresh row (<= 2 days old) in ``churn_feature_store``,
    computes per-row SHAP explanations for the positive (churn) class, and
    upserts one record per customer into ``churn_predictions``.

    Parameters
    ----------
    app_id : int
        Application whose customers should be scored.

    Returns
    -------
    None. All failures are reported via ``print`` and cause an early return;
    the DB connection is always closed before returning.
    """
    print(f"Connecting to DB and loading latest model for app_id={app_id}...")
    try:
        conn = psycopg2.connect(**DB_CONFIG)
    except Exception as e:
        print(f"Database connection error: {e}", file=sys.stderr)
        return

    # 1. Load the newest active model's artifact path and version tag.
    with conn.cursor() as cur:
        cur.execute("""
            SELECT model_path, model_version 
            FROM churn_model_registry 
            WHERE is_active = TRUE 
            ORDER BY trained_at DESC LIMIT 1
        """)
        row = cur.fetchone()
        if not row:
            print("No active model found. Train a model first.")
            conn.close()
            return

        model_path, model_version = row

    # The artifact bundles the estimator plus the exact feature layout it was
    # trained on, so inference can reproduce the training-time matrix.
    try:
        model_data = joblib.load(model_path)
        model = model_data['model']
        expected_features = model_data['features']
        numerical_features = model_data['numerical_features']
        categorical_features = model_data['categorical_features']
    except Exception as e:
        print(f"Failed to load model from {model_path}: {e}")
        conn.close()
        return

    print(f"Loaded model {model_version}. Fetching features to score...")

    # 2. Load features for all active customers (only rows computed recently,
    #    so we never score stale feature snapshots).
    query = """
    SELECT * FROM churn_feature_store 
    WHERE app_id = %s 
      AND computed_at >= NOW() - INTERVAL '2 days'
    """
    df = pd.read_sql(query, conn, params=(app_id,))

    if df.empty:
        print("No features found. Run feature extraction first.")
        conn.close()
        return

    # Extract ID columns so we can join predictions back
    ids = df[['customer_id', 'app_id']]

    # Process features to match training expectations.
    # BUG FIX: the original selected only numerical_features here, so
    # get_dummies() below had no categorical columns to encode and raised a
    # KeyError whenever categorical_features was non-empty. Select both
    # groups, and fill NaNs only on the numerical columns.
    X = df[list(numerical_features) + list(categorical_features)].copy()
    X[numerical_features] = X[numerical_features].fillna(0)

    if categorical_features:
        # One-hot encode the same way training did.
        X = pd.get_dummies(X, columns=categorical_features)

    # Align with the training-time layout regardless of whether categoricals
    # exist: add missing columns (categories absent from this batch) as 0,
    # drop unseen extras, and enforce the exact column order.
    X = X.reindex(columns=expected_features, fill_value=0)

    print(f"Scoring {len(X)} customers...")

    # 3. Predict churn probabilities (positive-class column).
    probabilities = model.predict_proba(X)[:, 1]

    # 4. Generate SHAP explanations
    print("Computing SHAP explanations...")
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X)

    # Some shap versions return a per-class list for binary classifiers;
    # take the positive class (churn) in that case.
    if isinstance(shap_values, list):
        shap_values = shap_values[1]

    # 5. Build prediction records
    predictions = []

    # Reverse dummy encoding simply by attributing the active dummy's impact
    # back to its root feature, e.g. device_type_ios -> device_type.
    # NOTE(review): impacts of the *inactive* dummies of a category are
    # intentionally dropped, matching the original behavior.
    for idx in range(len(X)):
        prob = float(probabilities[idx])
        risk_level = get_risk_level(prob)

        raw_impacts = {}   # root feature name -> SHAP impact
        row_values = {}    # root feature name -> human-readable value

        shap_row = shap_values[idx]
        x_row = X.iloc[idx]

        for i, col in enumerate(expected_features):
            impact = shap_row[i]
            val = x_row[col]

            # Check if this is a dummy column (e.g. device_type_Mobile)
            is_dummy = False
            for cat_feat in categorical_features:
                if col.startswith(f"{cat_feat}_"):
                    is_dummy = True
                    # Only the active dummy contributes its category + impact.
                    if val == 1:
                        cat_val = col.replace(f"{cat_feat}_", "")
                        raw_impacts[cat_feat] = raw_impacts.get(cat_feat, 0) + impact
                        row_values[cat_feat] = cat_val
                    break

            if not is_dummy:
                raw_impacts[col] = impact
                row_values[col] = val

        # Keep the 5 features with the largest absolute impact.
        feature_impacts = sorted(
            raw_impacts.items(),
            key=lambda item: abs(item[1]),
            reverse=True
        )[:5]

        top_factors = []
        for feat, impact in feature_impacts:
            val = row_values.get(feat, 0)
            top_factors.append({
                "feature": feat,
                "impact": round(float(impact), 4),
                # Categorical values stay as strings; numerics are rounded.
                "value": round(float(val), 2) if isinstance(val, (int, float)) and not np.isnan(val) else val,
                "direction": "increases_risk" if impact > 0 else "decreases_risk"
            })

        # Generate recommendations based on factors
        actions = generate_recommendations(top_factors, prob, risk_level)

        predictions.append((
            int(ids.iloc[idx]['customer_id']),
            int(ids.iloc[idx]['app_id']),
            prob,
            risk_level,
            None,  # predicted_churn_date placeholder
            float(min(1.0, abs(prob - 0.5) * 2)),  # confidence approximation
            json.dumps(top_factors),
            json.dumps(actions),
            model_version,
        ))

    # 6. Upsert predictions (one row per customer/app pair).
    print(f"Saving {len(predictions)} predictions to database...")
    from psycopg2.extras import execute_values

    try:
        with conn.cursor() as cur:
            execute_values(cur, """
                INSERT INTO churn_predictions 
                (customer_id, app_id, churn_probability, risk_level, 
                 predicted_churn_date, confidence, top_factors, 
                 recommended_actions, model_version, predicted_at)
                VALUES %s
                ON CONFLICT (customer_id, app_id) DO UPDATE SET
                    churn_probability = EXCLUDED.churn_probability,
                    risk_level = EXCLUDED.risk_level,
                    predicted_churn_date = EXCLUDED.predicted_churn_date,
                    confidence = EXCLUDED.confidence,
                    top_factors = EXCLUDED.top_factors,
                    recommended_actions = EXCLUDED.recommended_actions,
                    model_version = EXCLUDED.model_version,
                    predicted_at = EXCLUDED.predicted_at
            """, predictions, template="(%s, %s, %s, %s, %s, %s, %s::jsonb, %s::jsonb, %s, NOW())")
            conn.commit()

            critical = sum(1 for p in predictions if p[3] == 'critical')
            high = sum(1 for p in predictions if p[3] == 'high')
            print(f"Success! {critical} Critical Risk, {high} High Risk.")
    except Exception as e:
        print(f"Database error during insert: {e}")
    finally:
        conn.close()

def get_risk_level(probability):
    """Map a churn probability onto a discrete risk bucket.

    Thresholds (inclusive lower bounds) are kept in one table so they are
    easy to tune in one place.
    """
    buckets = (
        (0.70, 'critical'),
        (0.50, 'high'),
        (0.25, 'medium'),
    )
    for cutoff, label in buckets:
        if probability >= cutoff:
            return label
    return 'low'

def generate_recommendations(top_factors, probability, risk_level):
    """Generate actionable recommendations based on top risk factors.

    Only factors flagged as increasing risk trigger a retention play; a
    risk-level-specific baseline action is prepended for high/critical
    accounts, and the result is capped at 4 actions.
    """
    risky = {f['feature'] for f in top_factors if f['direction'] == 'increases_risk'}

    # Retention playbook: (triggering feature names, action). Order matters —
    # it determines the order of the returned actions.
    playbook = (
        ({'failed_payment_count_30d', 'failed_payment_count_90d'},
         "Trigger payment method update sequence"),
        ({'days_until_expiry'},
         "Send early renewal email with promotional code"),
        ({'auto_renew_enabled'},
         "Offer 1 month free to enable auto-renewal"),
        ({'has_downgraded'},
         "Promote features they miss from higher tier"),
        ({'days_since_last_login'},
         "Send re-engagement campaign ('We Miss You')"),
        ({'customer_tenure_days'},
         "Send 'Anniversary' or 'Milestone' appreciation gift"),
    )
    actions = [play for triggers, play in playbook if triggers & risky]

    # Phase 2 (engagement) factors would go here

    # Baseline actions by severity, always listed first.
    if risk_level == 'critical':
        actions.insert(0, "🚨 Priority Action: Assign to retention special team")
    elif risk_level == 'high':
        actions.insert(0, "Offer 10-20% proactive retention discount")

    if not actions:
        actions.append("Monitor account for further degradation")

    return actions[:4]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate batch predictions")
    parser.add_argument("--app-id", type=int, default=DEFAULT_APP_ID, help="App ID")
    args = parser.parse_args()
    
    if args.app_id <= 0:
        print("Error: --app-id must be positive.", file=sys.stderr)
        sys.exit(1)
        
    run_predictions(args.app_id)
