Skip to content

Commit c8e1e2a

Browse files
committed
add unit test for prefix cache tracking scorer
Signed-off-by: zhengkezhou1 <[email protected]>
1 parent 673ddd9 commit c8e1e2a

File tree

2 files changed

+155
-3
lines changed

2 files changed

+155
-3
lines changed

pkg/plugins/scorer/prefix_cache_tracking.go

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ import (
1616
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
1717
)
1818

19+
// PodScorer abstracts the source of per-pod KV-cache scores, so the scorer
// can be backed either by the real indexer or by a substitute in tests.
type PodScorer interface {
	// GetPodScores returns an integer score per pod key for the given prompt
	// and model, or an error when scores cannot be produced.
	// NOTE(review): how a nil pods slice is interpreted is up to the
	// implementation — confirm against the indexer being used.
	GetPodScores(ctx context.Context, prompt string, model string, pods []string) (map[string]int, error)
}
23+
1924
// PrefixCacheTrackingConfig holds the configuration for the
2025
// PrefixCacheTrackingScorer.
2126
type PrefixCacheTrackingConfig struct {
@@ -84,9 +89,19 @@ func New(ctx context.Context, config PrefixCacheTrackingConfig) (*PrefixCacheTra
8489
return &PrefixCacheTrackingScorer{
8590
typedName: plugins.TypedName{Type: prefix.PrefixCachePluginType},
8691
kvCacheIndexer: kvCacheIndexer,
92+
podScorer: kvCacheIndexer,
8793
}, nil
8894
}
8995

96+
// NewWithPodScorer creates a new PrefixCacheTrackingScorer with a custom PodScorer.
97+
// This is mainly used for testing to inject mock dependencies.
98+
func NewWithPodScorer(podScorer PodScorer) *PrefixCacheTrackingScorer {
99+
return &PrefixCacheTrackingScorer{
100+
typedName: plugins.TypedName{Type: prefix.PrefixCachePluginType},
101+
podScorer: podScorer,
102+
}
103+
}
104+
90105
// PrefixCacheTrackingScorer implements the framework.Scorer interface.
91106
// The scorer implements the `cache_tracking` mode of the prefix cache plugin.
92107
// It uses the `kvcache.Indexer` to score pods based on the KV-cache index
@@ -95,6 +110,7 @@ func New(ctx context.Context, config PrefixCacheTrackingConfig) (*PrefixCacheTra
95110
type PrefixCacheTrackingScorer struct {
96111
typedName plugins.TypedName
97112
kvCacheIndexer *kvcache.Indexer
113+
podScorer PodScorer
98114
}
99115

100116
// TypedName returns the typed name of the plugin.
@@ -114,13 +130,13 @@ func (s *PrefixCacheTrackingScorer) Score(ctx context.Context, _ *types.CycleSta
114130
loggerDebug := log.FromContext(ctx).WithName(s.typedName.String()).V(logutil.DEBUG)
115131
if request == nil {
116132
loggerDebug.Info("Request is nil, skipping scoring")
117-
return nil
133+
return make(map[types.Pod]float64)
118134
}
119135

120-
scores, err := s.kvCacheIndexer.GetPodScores(ctx, request.Prompt, request.TargetModel, nil)
136+
scores, err := s.podScorer.GetPodScores(ctx, request.Prompt, request.TargetModel, nil)
121137
if err != nil {
122138
loggerDebug.Error(err, "Failed to get pod scores")
123-
return nil
139+
return make(map[types.Pod]float64)
124140
}
125141
loggerDebug.Info("Got pod scores", "scores", scores)
126142

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
package scorer_test
2+
3+
import (
4+
"context"
5+
"errors"
6+
7+
"testing"
8+
9+
"github.com/google/go-cmp/cmp"
10+
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer"
11+
"github.com/stretchr/testify/require"
12+
k8stypes "k8s.io/apimachinery/pkg/types"
13+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
14+
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
15+
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
16+
)
17+
18+
// mockPodScorer is a canned scorer.PodScorer for tests: it answers with a
// fixed score table, or with a fixed error when one is configured.
type mockPodScorer struct {
	scores map[string]int // returned as-is when err is nil
	err    error          // when non-nil, returned instead of scores
}

// GetPodScores ignores every argument. It returns the configured scores when
// no error is set; otherwise it returns nil scores together with that error.
func (m *mockPodScorer) GetPodScores(_ context.Context, _, _ string, _ []string) (map[string]int, error) {
	if m.err == nil {
		return m.scores, nil
	}
	return nil, m.err
}
30+
31+
func TestPrefixCacheTracking_Score(t *testing.T) {
32+
testcases := []struct {
33+
name string
34+
pods []types.Pod
35+
request *types.LLMRequest
36+
mockScores map[string]int
37+
mockError error
38+
wantScoresByAddress map[string]float64 // Use address as key instead of Pod objects
39+
}{
40+
{
41+
name: "test normalized scores",
42+
pods: []types.Pod{
43+
&types.PodMetrics{
44+
Pod: &backend.Pod{
45+
NamespacedName: k8stypes.NamespacedName{Name: "pod-a"},
46+
Address: "10.0.0.1:8080",
47+
},
48+
MetricsState: &backendmetrics.MetricsState{
49+
WaitingQueueSize: 0,
50+
},
51+
},
52+
&types.PodMetrics{
53+
Pod: &backend.Pod{
54+
NamespacedName: k8stypes.NamespacedName{Name: "pod-b"},
55+
Address: "10.0.0.2:8080",
56+
},
57+
MetricsState: &backendmetrics.MetricsState{
58+
WaitingQueueSize: 1,
59+
},
60+
},
61+
&types.PodMetrics{
62+
Pod: &backend.Pod{
63+
NamespacedName: k8stypes.NamespacedName{Name: "pod-c"},
64+
Address: "10.0.0.3:8080",
65+
},
66+
MetricsState: &backendmetrics.MetricsState{
67+
WaitingQueueSize: 2,
68+
},
69+
},
70+
},
71+
request: &types.LLMRequest{
72+
TargetModel: "gpt-4",
73+
Prompt: "what is meaning of life?",
74+
},
75+
mockScores: map[string]int{
76+
"10.0.0.1:8080": 10,
77+
"10.0.0.2:8080": 20,
78+
"10.0.0.3:8080": 30,
79+
},
80+
wantScoresByAddress: map[string]float64{
81+
"10.0.0.1:8080": 0.0, // (10-10)/(30-10) = 0.0
82+
"10.0.0.2:8080": 0.5, // (20-10)/(30-10) = 0.5
83+
"10.0.0.3:8080": 1.0, // (30-10)/(30-10) = 1.0
84+
},
85+
},
86+
{
87+
name: "test nil request",
88+
pods: []types.Pod{
89+
&types.PodMetrics{
90+
Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}},
91+
},
92+
},
93+
request: nil,
94+
wantScoresByAddress: make(map[string]float64), // empty map instead of nil
95+
},
96+
{
97+
name: "test pod scorer error",
98+
pods: []types.Pod{
99+
&types.PodMetrics{
100+
Pod: &backend.Pod{
101+
NamespacedName: k8stypes.NamespacedName{Name: "pod-a"},
102+
Address: "10.0.0.1:8080",
103+
},
104+
},
105+
},
106+
request: &types.LLMRequest{
107+
TargetModel: "gpt-4",
108+
Prompt: "test prompt",
109+
},
110+
mockError: errors.New("test error"),
111+
wantScoresByAddress: make(map[string]float64), // empty map instead of nil
112+
},
113+
}
114+
115+
for _, tt := range testcases {
116+
t.Run(tt.name, func(t *testing.T) {
117+
mockScorer := &mockPodScorer{
118+
scores: tt.mockScores,
119+
err: tt.mockError,
120+
}
121+
prefixCacheScorer := scorer.NewWithPodScorer(mockScorer)
122+
require.NotNil(t, prefixCacheScorer)
123+
got := prefixCacheScorer.Score(context.Background(), nil, tt.request, tt.pods)
124+
// Convert the result to address-based map for easier comparison
125+
gotByAddress := make(map[string]float64)
126+
for pod, score := range got {
127+
if podMetrics, ok := pod.(*types.PodMetrics); ok && podMetrics.GetPod() != nil {
128+
gotByAddress[podMetrics.GetPod().Address] = score
129+
}
130+
}
131+
if diff := cmp.Diff(tt.wantScoresByAddress, gotByAddress); diff != "" {
132+
t.Errorf("Unexpected output (-want +got): %v", diff)
133+
}
134+
})
135+
}
136+
}

0 commit comments

Comments
 (0)