import os
import psycopg2
import pandas as pd
import pickle
import json
from psycopg2.extras import execute_values
from datetime import datetime

from config import DB_CONFIG

def get_db_connection():
    """Open and return a new PostgreSQL connection built from DB_CONFIG."""
    connection = psycopg2.connect(**DB_CONFIG)
    return connection

def _categorize_risk(prob):
    """Map a churn probability to a discrete risk tier label."""
    if prob > 0.80: return 'critical'
    elif prob > 0.50: return 'high'
    elif prob > 0.30: return 'medium'
    else: return 'low'


def _get_top_factor(row):
    """Build a JSON string of heuristic churn drivers for one feature row.

    In a fully scaled environment, we would run SHAP TreeExplainer here per user.
    For speed, we map standard driving logic based on numerical features to JSON.
    """
    factors = {}
    if row['days_until_expiry'] < 7:
        factors["Approaching Expiration"] = round(float(row['churn_probability']) * 0.6, 3)
    if row['total_revenue'] < 20:
        factors["Low Revenue Engagement"] = round(float(row['churn_probability']) * 0.4, 3)

    if not factors:
        factors["Baseline Behavior"] = round(float(row['churn_probability']), 3)

    return json.dumps(factors)


def execute_daily_predictions():
    """Score every profile in the feature store and rewrite the prediction snapshot.

    Loads the pickled model, reads features for all users from
    churn_feature_store, computes churn probabilities, derives a risk tier
    and a JSON factor blob per user, then TRUNCATEs churn_predictions and
    bulk-inserts the fresh rows in a single committed transaction.

    Side effects: reads models/lgbm_model_v1.pkl, rewrites the
    churn_predictions table, prints progress to stdout.
    """
    print("Initializing Daily Batch Inference...")

    model_path = 'models/lgbm_model_v1.pkl'
    if not os.path.exists(model_path):
        print(f"Error: Trained model not found at {model_path}. Run train_model.py first.")
        return

    # NOTE(review): pickle.load is only safe because this file is produced by
    # our own train_model.py — never point model_path at untrusted data.
    with open(model_path, 'rb') as f:
        model = pickle.load(f)

    conn = get_db_connection()
    try:
        query = """
        SELECT 
            sms_cust_id, 
            app_id, 
            days_until_expiry, 
            total_revenue
        FROM churn_feature_store
        """

        df = pd.read_sql_query(query, conn)
        if df.empty:
            print("Empty feature store. No predictions to generate.")
            return  # the finally block still closes the connection

        print(f"Loaded {len(df)} active user profiles to predict.")

        # Feature order must match the training pipeline — assumed to mirror
        # train_model.py; TODO confirm against that script.
        features = ['days_until_expiry', 'total_revenue']
        X = df[features]

        # Run Inference: column 1 of predict_proba is the positive (churn) class.
        probabilities = model.predict_proba(X)
        df['churn_probability'] = probabilities[:, 1]

        # Categorize Risk
        df['risk_level'] = df['churn_probability'].apply(_categorize_risk)

        df['top_factors'] = df.apply(_get_top_factor, axis=1)
        # Naive local timestamp — presumably the DB column is timezone-naive;
        # verify before switching to an aware datetime.
        df['predicted_at'] = datetime.now()
        df['model_version'] = 'LGBM-Pipeline-v1'

        # Wipe the existing prediction snapshot to keep UI speedy, or UPSERT.
        # TRUNCATE + INSERT share one transaction, so readers see either the
        # old snapshot or the new one — never an empty table.
        with conn.cursor() as cursor:
            cursor.execute("TRUNCATE TABLE churn_predictions CASCADE;")

            insert_query = """
                INSERT INTO churn_predictions 
                (sms_cust_id, app_id, churn_probability, risk_level, top_factors, model_version, predicted_at)
                VALUES %s
            """

            # to_records(...).tolist() yields plain Python scalars, which
            # psycopg2 can adapt (raw numpy scalars it cannot).
            columns = ['sms_cust_id', 'app_id', 'churn_probability', 'risk_level',
                       'top_factors', 'model_version', 'predicted_at']
            records = df[columns].to_records(index=False)

            execute_values(cursor, insert_query, records.tolist())
        conn.commit()

        print(f"Inference Completed! Successfully wrote {len(records)} prediction scores to DB.")
    finally:
        # Guarantee the connection is released on early return or error —
        # the original leaked it when the feature store was empty.
        conn.close()

if __name__ == "__main__":
    # Script entry point: run the daily batch scoring job end-to-end.
    execute_daily_predictions()
