import os
import psycopg2
import pandas as pd
import lightgbm as lgb
import pickle
import json
from datetime import datetime
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

from config import DB_CONFIG, FEATURE_COLUMNS, PHASE2_FEATURE_COLUMNS

# USER EDIT: Replace this mapping value with exactly what your system considers "Expired" in tbl_customer_subscriptions.status
# Rows whose `status` equals this code are labeled as churned (is_churned = 1) during training.
EXPIRED_STATUS_CODE: int = 7

def get_db_connection():
    """Open and return a new PostgreSQL connection built from DB_CONFIG.

    Callers are responsible for closing the returned connection.
    """
    connection = psycopg2.connect(**DB_CONFIG)
    return connection

def train_model():
    """Train and register the churn-prediction model.

    Pipeline: pull Phase 1 + Phase 2 features joined with subscription status,
    derive the churn label, train a LightGBM classifier, evaluate on a held-out
    split, pickle the model bundle to models/lgbm_model_v1.pkl, and upsert the
    run's metrics into churn_model_registry.

    Returns None; progress and metrics are printed to stdout.
    Side effects: one file written under models/, one registry row upserted.
    """
    print("Extracting features from churn_feature_store...")
    conn = get_db_connection()
    # try/finally guarantees the connection is released even on the early
    # "not enough data" return below or on any exception mid-pipeline
    # (the original leaked the connection in both cases).
    try:
        # Use both Phase 1 and Phase 2 features for better prediction.
        features = FEATURE_COLUMNS + PHASE2_FEATURE_COLUMNS
        # NOTE: column names are interpolated into the SQL text. They come from
        # our own config constants (trusted), never from user input.
        feature_cols_sql = ", ".join(f"f.{col}" for col in features)

        query = f"""
        SELECT 
            c.sms_cust_id, 
            c.status,
            {feature_cols_sql}
        FROM tbl_customer_subscriptions c
        JOIN churn_feature_store f ON c.sms_cust_id = f.sms_cust_id
        """

        df = pd.read_sql_query(query, conn)

        if len(df) < 50:
            print("Not enough data to train. Need mock execution or larger historical DB.")
            return

        # Ground truth: 1 = churned (status == Expired), 0 = active/other.
        df['is_churned'] = (df['status'] == EXPIRED_STATUS_CODE).astype(int)

        # Crude imputation: treat missing feature values as 0.
        df = df.fillna(0)

        X = df[features]
        y = df['is_churned']

        # Stratify the split when every class has at least 2 members so the
        # test set mirrors the ~85% churn skew; otherwise fall back to a plain
        # random split (stratify would raise on a 1-member class).
        stratify_labels = y if y.value_counts().min() >= 2 else None
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=stratify_labels
        )

        print(f"Training Improved LightGBM model on {len(X_train)} samples with {len(features)} features...")
        # is_unbalance=True reweights classes to handle the ~85% churn skew.
        model = lgb.LGBMClassifier(
            n_estimators=200,
            learning_rate=0.03,
            max_depth=7,
            is_unbalance=True,
            random_state=42,
        )
        model.fit(X_train, y_train)

        print(f"Evaluating model on {len(X_test)} samples...")
        y_pred = model.predict(X_test)
        y_prob = model.predict_proba(X_test)[:, 1]

        accuracy = accuracy_score(y_test, y_pred)
        precision = precision_score(y_test, y_pred, zero_division=0)
        recall = recall_score(y_test, y_pred, zero_division=0)
        f1 = f1_score(y_test, y_pred, zero_division=0)

        # ROC-AUC is undefined when the test set has a single class; report
        # 0.5 (chance level) in that degenerate case.
        if y_test.nunique() > 1:
            auc_roc = roc_auc_score(y_test, y_prob)
        else:
            auc_roc = 0.5

        print(f"Improved Results - Acc: {accuracy:.4f}, Prec: {precision:.4f}, Rec: {recall:.4f}, F1: {f1:.4f}, AUC: {auc_roc:.4f}")

        os.makedirs('models', exist_ok=True)
        model_path = 'models/lgbm_model_v1.pkl'

        # Standard metadata package consumed by predict.py.
        model_data = {
            'model': model,
            'features': features,
            'numerical_features': features,  # currently all features are numerical/binary
            'categorical_features': []       # placeholder for future categorical features
        }

        with open(model_path, 'wb') as f:
            pickle.dump(model_data, f)

        print(f"Algorithm Trained Successfully! Serialized at {model_path} with metadata.")

        training_samples = len(X_train)
        positive_samples = int(y_train.sum())
        negative_samples = training_samples - positive_samples
        feature_count = len(features)
        feature_importance = json.dumps(dict(zip(features, model.feature_importances_.tolist())))

        # Upsert this run's metrics into churn_model_registry, keyed on
        # model_version so re-training overwrites the previous row.
        registry_query = """
            INSERT INTO churn_model_registry 
            (model_version, accuracy, precision_score, recall_score, f1_score, auc_roc, training_samples, positive_samples, negative_samples, feature_count, feature_importance, is_active, trained_at)
            VALUES ('LGBM-Pipeline-v1', %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, true, %s)
            ON CONFLICT (model_version) DO UPDATE SET
                accuracy = EXCLUDED.accuracy,
                precision_score = EXCLUDED.precision_score,
                recall_score = EXCLUDED.recall_score,
                f1_score = EXCLUDED.f1_score,
                auc_roc = EXCLUDED.auc_roc,
                training_samples = EXCLUDED.training_samples,
                positive_samples = EXCLUDED.positive_samples,
                negative_samples = EXCLUDED.negative_samples,
                feature_count = EXCLUDED.feature_count,
                feature_importance = EXCLUDED.feature_importance,
                trained_at = EXCLUDED.trained_at
        """
        # `with` closes the cursor even if execute() raises (the original
        # leaked the cursor on failure).
        with conn.cursor() as cursor:
            cursor.execute(registry_query, (
                float(accuracy), float(precision), float(recall), float(f1), float(auc_roc),
                training_samples, positive_samples, negative_samples, feature_count, feature_importance,
                datetime.now()
            ))
        conn.commit()

        print("Database Registry Linked.")
    finally:
        conn.close()

if __name__ == "__main__":
    # Script entry point: run the full extract/train/evaluate/register pipeline.
    train_model()
