"""
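Generate churn labels from historical subscription data.

For each customer, monthly observation points are created over a lookback
window, and each point is labelled by whether the customer churned within
the 30 days that follow it.

Label = 1 (churned): a subscription with status 3 or 7 has an end_date within
    30 days after the observation date.
Label = 0 (not churned): no status 3 or 7 subscription ended in that window.
    Active status (2) is not checked, so label 0 does not by itself show that
    the customer stayed active.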
"""

import sys
import argparse
import psycopg2
from config import DB_CONFIG, DEFAULT_APP_ID


# Upserts one churn label per (customer, app, month-start observation date).
# Parameters: %(app_id)s and %(months)s (lookback length in months).
# Placeholders are never placed inside quoted SQL literals (psycopg2 requires
# this); the interval is built as INTERVAL '1 month' * %(months)s instead.
#
# NOTE(review): two population caveats are kept as-is pending a schema
# decision -- confirm the intended behaviour:
#   * only customers with a subscription row created inside the lookback
#     window are observed, so long-standing customers without recent rows are
#     excluded;
#   * a customer who churned before an observation date is still observed
#     and, absent another churn event in the window, labelled not-churned.
LABEL_QUERY = """
WITH customer_first_seen AS (
    -- Earliest subscription record per customer inside the lookback window.
    SELECT
        sms_cust_id,
        app_id,
        MIN(created_at) AS first_created_at
    FROM tbl_customer_subscriptions
    WHERE app_id = %(app_id)s
      AND created_at >= NOW() - INTERVAL '1 month' * %(months)s
    GROUP BY sms_cust_id, app_id
),
observation_dates AS (
    -- First day of each month in the window. The series stops 30 days before
    -- now so every outcome window has fully elapsed.
    SELECT DISTINCT date_trunc('month', gs.ts)::DATE AS observation_date
    FROM generate_series(
        (NOW() - INTERVAL '1 month' * %(months)s)::DATE,
        (NOW() - INTERVAL '30 days')::DATE,
        INTERVAL '1 month'
    ) AS gs(ts)
),
observation_points AS (
    -- Compare against the truncated date that is actually stored, so no
    -- customer is observed before their first subscription record.
    SELECT c.sms_cust_id, c.app_id, d.observation_date
    FROM customer_first_seen c
    JOIN observation_dates d
      ON c.first_created_at <= d.observation_date
),
labelled AS (
    SELECT
        op.sms_cust_id,
        op.app_id,
        op.observation_date,
        -- Earliest churn-status (3 or 7) end date in the half-open window
        -- [observation_date, observation_date + 30 days); NULL if none.
        (
            SELECT MIN(cs.end_date)::DATE
            FROM tbl_customer_subscriptions cs
            WHERE cs.sms_cust_id = op.sms_cust_id
              AND cs.app_id = op.app_id
              AND cs.status IN (3, 7)
              AND cs.end_date >= op.observation_date
              AND cs.end_date < op.observation_date + INTERVAL '30 days'
        ) AS churn_date
    FROM observation_points op
)
INSERT INTO churn_training_labels (sms_cust_id, app_id, churned, observation_date, churn_date)
SELECT sms_cust_id, app_id, churn_date IS NOT NULL AS churned, observation_date, churn_date
FROM labelled
ON CONFLICT (sms_cust_id, app_id, observation_date) DO UPDATE
SET churned = EXCLUDED.churned,
    churn_date = EXCLUDED.churn_date;
"""

_COUNT_LABELS_QUERY = "SELECT count(*) FROM churn_training_labels WHERE app_id = %s"


def _count_labels(cur, app_id):
    """Return the number of rows in churn_training_labels for *app_id*."""
    cur.execute(_COUNT_LABELS_QUERY, (app_id,))
    return cur.fetchone()[0]


def generate_labels(app_id, lookback_months=6):
    """Generate training labels from historical subscription data.

    Runs ``LABEL_QUERY``, which inserts or updates one row per
    (customer, app, month-start observation date) in
    ``churn_training_labels``. Existing rows are updated, never deleted.
    Progress and a summary are printed to stdout.

    Args:
        app_id: App whose subscriptions are labelled.
        lookback_months: Months of subscription history to use.

    Returns:
        The number of label rows newly added. Rows that already existed and
        were only updated are not counted.

    On a database error, prints the error to stderr and exits the process
    with status 1.
    """
    params = {'app_id': app_id, 'months': lookback_months}
    print(f"Connecting to DB and generating labels for app_id={app_id} (lookback: {lookback_months}m)...")
    try:
        conn = psycopg2.connect(**DB_CONFIG)
        try:
            # ``with conn`` commits on success and rolls back on error, but
            # does NOT close the connection -- hence the explicit close below.
            with conn, conn.cursor() as cur:
                count_before = _count_labels(cur, app_id)
                cur.execute(LABEL_QUERY, params)
                count_after = _count_labels(cur, app_id)
        finally:
            conn.close()
    except psycopg2.Error as e:
        print(f"Database error: {e}", file=sys.stderr)
        sys.exit(1)

    # Reached only after the transaction has been committed.
    added = count_after - count_before
    print(f"Success! Labels generated. Added {added} new observation points.")
    print(f"Total labels for app_id={app_id}: {count_after}")
    return added


def main(argv=None):
    """Parse command-line arguments and generate churn labels.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Exits with status 1 if --app-id or --months is not a positive integer.
    """
    parser = argparse.ArgumentParser(description="Generate churn labels for training.")
    parser.add_argument("--app-id", type=int, default=DEFAULT_APP_ID, help="App ID to process")
    parser.add_argument("--months", type=int, default=6, help="Months of history to look back")
    args = parser.parse_args(argv)

    if args.app_id <= 0:
        print("Error: --app-id must be a positive integer.", file=sys.stderr)
        sys.exit(1)
    # A zero or negative lookback gives an empty (or future) window, so the
    # run would insert nothing and still report success.
    if args.months <= 0:
        print("Error: --months must be a positive integer.", file=sys.stderr)
        sys.exit(1)

    generate_labels(args.app_id, args.months)


if __name__ == "__main__":
    main()
