package drift import ( "testing" "git.agentview.dev/profit/golangLAKEHOUSE/internal/distillation" ) func mkInput(sourceFile string, persisted distillation.ScoreCategory, succ []string) ScorerDriftInput { return ScorerDriftInput{ Record: distillation.EvidenceRecord{ RunID: "run-x", TaskID: "task-x", Timestamp: "2026-01-01T00:00:00Z", SchemaVersion: distillation.EvidenceSchemaVersion, Provenance: distillation.Provenance{ SourceFile: sourceFile, SigHash: "abc", RecordedAt: "2026-01-01T00:00:01Z", }, SuccessMarkers: succ, }, PersistedCategory: persisted, } } func TestComputeScorerDrift_NoDrift(t *testing.T) { // All inputs have persisted=accepted matching what the current // scrum_review scorer produces on accepted_on_attempt_1. inputs := []ScorerDriftInput{ mkInput("data/_kb/scrum_reviews.jsonl", distillation.CategoryAccepted, []string{"accepted_on_attempt_1"}), mkInput("data/_kb/scrum_reviews.jsonl", distillation.CategoryAccepted, []string{"accepted_on_attempt_1"}), mkInput("data/_kb/scrum_reviews.jsonl", distillation.CategoryAccepted, []string{"accepted_on_attempt_1"}), } r := ComputeScorerDrift(inputs, true) if r.TotalChecked != 3 || r.Matched != 3 || r.Drifted != 0 { t.Errorf("no-drift case: total=%d matched=%d drifted=%d", r.TotalChecked, r.Matched, r.Drifted) } if r.DriftRate != 0 { t.Errorf("drift_rate: want 0, got %v", r.DriftRate) } if len(r.Entries) != 0 { t.Errorf("entries: want 0, got %d", len(r.Entries)) } } func TestComputeScorerDrift_ShiftDetected(t *testing.T) { // Simulate a historical labeling where the persisted scorer // thought attempt-2 acceptances were "accepted" but the current // scorer (this code) categorizes them as "partially_accepted". // Drift should fire on those. inputs := []ScorerDriftInput{ // Match: attempt 1 → accepted (still) mkInput("data/_kb/scrum_reviews.jsonl", distillation.CategoryAccepted, []string{"accepted_on_attempt_1"}), mkInput("data/_kb/scrum_reviews.jsonl", distillation.CategoryAccepted, []string{"accepted_on_attempt_1"}), // Drift: persisted thought attempt-2 was accepted, today's scorer says partial mkInput("data/_kb/scrum_reviews.jsonl", distillation.CategoryAccepted, []string{"accepted_on_attempt_2"}), mkInput("data/_kb/scrum_reviews.jsonl", distillation.CategoryAccepted, []string{"accepted_on_attempt_3"}), // Drift: persisted thought attempt-5 was accepted, today's scorer says partial (high-cost) mkInput("data/_kb/scrum_reviews.jsonl", distillation.CategoryAccepted, []string{"accepted_on_attempt_5"}), } r := ComputeScorerDrift(inputs, true) if r.TotalChecked != 5 { t.Errorf("total: want 5, got %d", r.TotalChecked) } if r.Matched != 2 { t.Errorf("matched: want 2, got %d", r.Matched) } if r.Drifted != 3 { t.Errorf("drifted: want 3, got %d", r.Drifted) } wantRate := 3.0 / 5.0 if r.DriftRate < wantRate-1e-9 || r.DriftRate > wantRate+1e-9 { t.Errorf("drift_rate: want %v, got %v", wantRate, r.DriftRate) } if len(r.Entries) != 3 { t.Errorf("entries: want 3 mismatches, got %d", len(r.Entries)) } // Shift matrix should show one shift: accepted → partially_accepted, count=3 if len(r.ShiftMatrix) != 1 { t.Errorf("shift matrix: want 1 shift, got %d (%+v)", len(r.ShiftMatrix), r.ShiftMatrix) } else { s := r.ShiftMatrix[0] if s.From != distillation.CategoryAccepted || s.To != distillation.CategoryPartiallyAccepted || s.Count != 3 { t.Errorf("shift: got %+v", s) } } } func TestComputeScorerDrift_MultipleShiftsSortedByCount(t *testing.T) { inputs := []ScorerDriftInput{ // 3× accepted→partial mkInput("data/_kb/scrum_reviews.jsonl", distillation.CategoryAccepted, []string{"accepted_on_attempt_2"}), mkInput("data/_kb/scrum_reviews.jsonl", distillation.CategoryAccepted, []string{"accepted_on_attempt_2"}), mkInput("data/_kb/scrum_reviews.jsonl", distillation.CategoryAccepted, []string{"accepted_on_attempt_2"}), // 1× rejected→needs_human (no marker) { Record: distillation.EvidenceRecord{ RunID: "r1", TaskID: "t1", Timestamp: "2026-01-01T00:00:00Z", SchemaVersion: distillation.EvidenceSchemaVersion, Provenance: distillation.Provenance{ SourceFile: "data/_kb/scrum_reviews.jsonl", SigHash: "x", RecordedAt: "2026-01-01T00:00:01Z", }, // no markers → needs_human_review }, PersistedCategory: distillation.CategoryRejected, }, } r := ComputeScorerDrift(inputs, false) if r.Drifted != 4 { t.Errorf("drifted: want 4, got %d", r.Drifted) } if len(r.ShiftMatrix) != 2 { t.Errorf("shift matrix: want 2 distinct shifts, got %d", len(r.ShiftMatrix)) } // Sorted by count desc, so accepted→partial (3) before rejected→needs_human (1) if r.ShiftMatrix[0].Count != 3 || r.ShiftMatrix[1].Count != 1 { t.Errorf("shift order wrong: got %+v", r.ShiftMatrix) } } func TestComputeScorerDrift_IncludeEntriesFalse(t *testing.T) { inputs := []ScorerDriftInput{ mkInput("data/_kb/scrum_reviews.jsonl", distillation.CategoryAccepted, []string{"accepted_on_attempt_2"}), } r := ComputeScorerDrift(inputs, false) if r.Drifted != 1 { t.Errorf("drifted: want 1, got %d", r.Drifted) } if len(r.Entries) != 0 { t.Errorf("entries: want 0 when includeEntries=false, got %d", len(r.Entries)) } } func TestComputeScorerDrift_EmptyInput(t *testing.T) { r := ComputeScorerDrift(nil, true) if r.TotalChecked != 0 || r.Drifted != 0 || r.Matched != 0 { t.Errorf("empty: want all-zero, got %+v", r) } if r.DriftRate != 0 { t.Errorf("drift_rate on empty: want 0, got %v", r.DriftRate) } } func TestComputeScorerDrift_ScorerVersionStamped(t *testing.T) { r := ComputeScorerDrift(nil, false) if r.ScorerVersion != distillation.ScorerVersion { t.Errorf("scorer_version: want %q, got %q", distillation.ScorerVersion, r.ScorerVersion) } } // ── Distribution drift (PSI) tests ──────────────────────────────── // TestDistributionDrift_IdenticalIsStable: same data on both sides // should yield PSI ≈ 0 and tier=stable. Anchors the lower bound. func TestDistributionDrift_IdenticalIsStable(t *testing.T) { data := []float64{0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9} r := ComputeDistributionDrift(DistributionDriftInput{ Baseline: data, Current: data, NumBuckets: 5, }) if r.PSI > 0.001 { t.Errorf("identical distributions: expected PSI ≈ 0, got %f", r.PSI) } if r.Tier != DriftTierStable { t.Errorf("expected stable tier, got %q", r.Tier) } } // TestDistributionDrift_HardShiftIsMajor: distribution moved // completely to a different range — should yield major tier. func TestDistributionDrift_HardShiftIsMajor(t *testing.T) { baseline := []float64{0.1, 0.1, 0.2, 0.2, 0.3, 0.3, 0.4, 0.4} current := []float64{0.7, 0.7, 0.8, 0.8, 0.9, 0.9, 1.0, 1.0} r := ComputeDistributionDrift(DistributionDriftInput{ Baseline: baseline, Current: current, NumBuckets: 10, }) if r.PSI < 0.25 { t.Errorf("hard distribution shift: expected PSI ≥ 0.25, got %f", r.PSI) } if r.Tier != DriftTierMajor { t.Errorf("expected major tier, got %q", r.Tier) } } // TestDistributionDrift_DetectsModerateShift: distribution shifted // noticeably but not catastrophically — PSI must be > 0 (some drift // detected) and tier must NOT be stable. Whether the tier is minor // vs major depends on bucketing granularity; we don't pin that here // because PSI thresholds are sensitive to bucket count. func TestDistributionDrift_DetectsModerateShift(t *testing.T) { // Baseline: many around 0.5, some spread. baseline := []float64{0.4, 0.45, 0.5, 0.5, 0.5, 0.5, 0.5, 0.55, 0.6, 0.6, 0.45, 0.5, 0.5, 0.55, 0.5, 0.5, 0.5, 0.55, 0.5, 0.6} // Current: same range, slight rightward shift (still overlapping). current := []float64{0.45, 0.5, 0.5, 0.55, 0.55, 0.55, 0.6, 0.6, 0.65, 0.7, 0.5, 0.55, 0.55, 0.55, 0.6, 0.6, 0.55, 0.6, 0.65, 0.65} r := ComputeDistributionDrift(DistributionDriftInput{ Baseline: baseline, Current: current, NumBuckets: 10, }) if r.PSI < 0.01 { t.Errorf("moderate shift should produce PSI > 0.01, got %f", r.PSI) } if r.Tier == DriftTierStable { t.Errorf("moderate shift should NOT be stable tier, got PSI=%f tier=%q", r.PSI, r.Tier) } } // TestDistributionDrift_EmptyInputs: empty baseline OR current // returns PSI=0, stable tier — caller must check N before trusting. func TestDistributionDrift_EmptyInputs(t *testing.T) { r := ComputeDistributionDrift(DistributionDriftInput{ Baseline: []float64{}, Current: []float64{1, 2, 3}, }) if r.PSI != 0 || r.Tier != DriftTierStable { t.Errorf("empty baseline: expected PSI=0 stable, got psi=%f tier=%q", r.PSI, r.Tier) } r = ComputeDistributionDrift(DistributionDriftInput{ Baseline: []float64{1, 2, 3}, Current: []float64{}, }) if r.PSI != 0 || r.Tier != DriftTierStable { t.Errorf("empty current: expected PSI=0 stable, got psi=%f tier=%q", r.PSI, r.Tier) } } // TestDistributionDrift_AllIdenticalValues: degenerate case where // everything's the same value (e.g., all zeros). Should not panic; // returns stable. func TestDistributionDrift_AllIdenticalValues(t *testing.T) { r := ComputeDistributionDrift(DistributionDriftInput{ Baseline: []float64{0.5, 0.5, 0.5}, Current: []float64{0.5, 0.5, 0.5}, }) if r.Tier != DriftTierStable { t.Errorf("expected stable on identical-singleton, got %q", r.Tier) } } // TestDistributionDrift_BucketCounts: per-bucket counts must sum to // the input N. If they don't, we're losing observations to bucket // boundary issues. func TestDistributionDrift_BucketCounts(t *testing.T) { baseline := []float64{0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0} current := []float64{0.5, 0.5, 0.5, 0.5} r := ComputeDistributionDrift(DistributionDriftInput{ Baseline: baseline, Current: current, NumBuckets: 5, }) totalB := 0 totalC := 0 for _, b := range r.Buckets { totalB += b.BaselineCount totalC += b.CurrentCount } if totalB != len(baseline) { t.Errorf("baseline bucket counts sum to %d, expected %d", totalB, len(baseline)) } if totalC != len(current) { t.Errorf("current bucket counts sum to %d, expected %d", totalC, len(current)) } } // TestDistributionDrift_NumBucketsClamping: 0 → default 10; > 100 → 100. func TestDistributionDrift_NumBucketsClamping(t *testing.T) { in := DistributionDriftInput{Baseline: []float64{1, 2}, Current: []float64{1, 2}} in.NumBuckets = 0 r := ComputeDistributionDrift(in) if r.NumBuckets != 10 { t.Errorf("0 should default to 10 buckets, got %d", r.NumBuckets) } in.NumBuckets = 500 r = ComputeDistributionDrift(in) if r.NumBuckets != 100 { t.Errorf("500 should clamp to 100 buckets, got %d", r.NumBuckets) } }