Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions pkg/plugins/scorer/prefix_cache_tracking.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ import (
logutil "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging"
)

// PodScorer defines the interface for scoring pods based on KV-cache state.
// PrefixCacheTrackingScorer obtains its raw, un-normalized scores through this
// interface, which lets tests inject a mock in place of the real indexer.
type PodScorer interface {
// GetPodScores returns one integer score per pod for the given prompt and
// model, or an error if scoring failed.
// NOTE(review): the map keys look like pod addresses and pods looks like an
// optional filter list (Score passes nil) — confirm against the indexer docs.
GetPodScores(ctx context.Context, prompt, model string, pods []string) (map[string]int, error)
}

// PrefixCacheTrackingConfig holds the configuration for the
// PrefixCacheTrackingScorer.
type PrefixCacheTrackingConfig struct {
Expand Down Expand Up @@ -84,9 +89,19 @@ func New(ctx context.Context, config PrefixCacheTrackingConfig) (*PrefixCacheTra
return &PrefixCacheTrackingScorer{
typedName: plugins.TypedName{Type: prefix.PrefixCachePluginType},
kvCacheIndexer: kvCacheIndexer,
podScorer: kvCacheIndexer,
}, nil
}

// NewWithPodScorer builds a PrefixCacheTrackingScorer whose scores come from
// the supplied PodScorer rather than from a kvcache.Indexer.
// It exists primarily so tests can inject a mock dependency; the
// kvCacheIndexer field is left at its zero value (nil).
func NewWithPodScorer(podScorer PodScorer) *PrefixCacheTrackingScorer {
	s := new(PrefixCacheTrackingScorer)
	s.typedName = plugins.TypedName{Type: prefix.PrefixCachePluginType}
	s.podScorer = podScorer
	return s
}

// PrefixCacheTrackingScorer implements the framework.Scorer interface.
// The scorer implements the `cache_tracking` mode of the prefix cache plugin.
// It uses the `kvcache.Indexer` to score pods based on the KV-cache index
Expand All @@ -95,6 +110,7 @@ func New(ctx context.Context, config PrefixCacheTrackingConfig) (*PrefixCacheTra
type PrefixCacheTrackingScorer struct {
// typedName identifies the plugin; Score also uses it to name its logger.
typedName plugins.TypedName
// kvCacheIndexer is the concrete indexer set by New. NewWithPodScorer does
// not set it, so it is nil for scorers built that way.
kvCacheIndexer *kvcache.Indexer
// podScorer is what Score calls to fetch per-pod scores. New points it at
// the kvCacheIndexer; NewWithPodScorer stores the caller-supplied value.
podScorer PodScorer
}

// TypedName returns the typed name of the plugin.
Expand All @@ -114,13 +130,13 @@ func (s *PrefixCacheTrackingScorer) Score(ctx context.Context, _ *types.CycleSta
loggerDebug := log.FromContext(ctx).WithName(s.typedName.String()).V(logutil.DEBUG)
if request == nil {
loggerDebug.Info("Request is nil, skipping scoring")
return nil
return make(map[types.Pod]float64)
}

scores, err := s.kvCacheIndexer.GetPodScores(ctx, request.Prompt, request.TargetModel, nil)
scores, err := s.podScorer.GetPodScores(ctx, request.Prompt, request.TargetModel, nil)
if err != nil {
loggerDebug.Error(err, "Failed to get pod scores")
return nil
return make(map[types.Pod]float64)
}
loggerDebug.Info("Got pod scores", "scores", scores)

Expand Down
136 changes: 136 additions & 0 deletions pkg/plugins/scorer/prefix_cache_tracking_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
package scorer_test

import (
"context"
"errors"

"testing"

"github.com/google/go-cmp/cmp"
"github.com/llm-d/llm-d-inference-scheduler/pkg/plugins/scorer"
"github.com/stretchr/testify/require"
k8stypes "k8s.io/apimachinery/pkg/types"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend"
backendmetrics "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics"
"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/scheduling/types"
)

// mockPodScorer is a canned scorer.PodScorer used by the tests: it hands back
// whatever scores/err it was constructed with and ignores every argument.
type mockPodScorer struct {
	scores map[string]int // returned verbatim when err is nil
	err    error          // when non-nil, returned instead of scores
}

// GetPodScores returns the configured scores, or (nil, err) if an error was
// configured. The context, prompt, model and pod list are all ignored.
func (m *mockPodScorer) GetPodScores(_ context.Context, _, _ string, _ []string) (map[string]int, error) {
	if m.err == nil {
		return m.scores, nil
	}
	return nil, m.err
}

func TestPrefixCacheTracking_Score(t *testing.T) {
testcases := []struct {
name string
pods []types.Pod
request *types.LLMRequest
mockScores map[string]int
mockError error
wantScoresByAddress map[string]float64 // Use address as key instead of Pod objects
}{
{
name: "test normalized scores",
pods: []types.Pod{
&types.PodMetrics{
Pod: &backend.Pod{
NamespacedName: k8stypes.NamespacedName{Name: "pod-a"},
Address: "10.0.0.1:8080",
},
MetricsState: &backendmetrics.MetricsState{
WaitingQueueSize: 0,
},
},
&types.PodMetrics{
Pod: &backend.Pod{
NamespacedName: k8stypes.NamespacedName{Name: "pod-b"},
Address: "10.0.0.2:8080",
},
MetricsState: &backendmetrics.MetricsState{
WaitingQueueSize: 1,
},
},
&types.PodMetrics{
Pod: &backend.Pod{
NamespacedName: k8stypes.NamespacedName{Name: "pod-c"},
Address: "10.0.0.3:8080",
},
MetricsState: &backendmetrics.MetricsState{
WaitingQueueSize: 2,
},
},
},
request: &types.LLMRequest{
TargetModel: "gpt-4",
Prompt: "what is meaning of life?",
},
mockScores: map[string]int{
"10.0.0.1:8080": 10,
"10.0.0.2:8080": 20,
"10.0.0.3:8080": 30,
},
wantScoresByAddress: map[string]float64{
"10.0.0.1:8080": 0.0, // (10-10)/(30-10) = 0.0
"10.0.0.2:8080": 0.5, // (20-10)/(30-10) = 0.5
"10.0.0.3:8080": 1.0, // (30-10)/(30-10) = 1.0
},
},
{
name: "test nil request",
pods: []types.Pod{
&types.PodMetrics{
Pod: &backend.Pod{NamespacedName: k8stypes.NamespacedName{Name: "pod-a"}},
},
},
request: nil,
wantScoresByAddress: make(map[string]float64), // empty map instead of nil
},
{
name: "test pod scorer error",
pods: []types.Pod{
&types.PodMetrics{
Pod: &backend.Pod{
NamespacedName: k8stypes.NamespacedName{Name: "pod-a"},
Address: "10.0.0.1:8080",
},
},
},
request: &types.LLMRequest{
TargetModel: "gpt-4",
Prompt: "test prompt",
},
mockError: errors.New("test error"),
wantScoresByAddress: make(map[string]float64), // empty map instead of nil
},
}

for _, tt := range testcases {
t.Run(tt.name, func(t *testing.T) {
mockScorer := &mockPodScorer{
Copy link
Member

@vMaroon vMaroon Aug 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The precise prefix-cache scorer is a wrapper of the kvcache.Indexer. The wrapping logic is mostly normalization.

What this test ends up doing is testing the latter - which is great, but that can be achieved by testing the indexedScoresToNormalizedScoredPods function directly instead of using a mock PodScorer interface.

I think this test should verify correctness of the scorer as-is. This would require:

  1. Access to the indexer's backing kvblock.Index
  2. Having the test populate the kvblock.Index with kv-block information
  3. Testing correctness of the scorer as-is

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding data to the kvblock.Index alone isn't sufficient; we also need to add data to the tokensIndexer. However, the kv-cache-manager currently only allows access to the KVBlockIndex. Can we add a method to the kv-cache-manager to expose the tokensIndexer?

Copy link
Member

@vMaroon vMaroon Aug 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can add data to the tokenization cache by making the calls in a "warmup" step (call once, get no scores). Note that this behavior will change soon once tokenization is synchronous, after which there will be no need for such a step.

scores: tt.mockScores,
err: tt.mockError,
}
prefixCacheScorer := scorer.NewWithPodScorer(mockScorer)
require.NotNil(t, prefixCacheScorer)
got := prefixCacheScorer.Score(context.Background(), nil, tt.request, tt.pods)
// Convert the result to address-based map for easier comparison
gotByAddress := make(map[string]float64)
for pod, score := range got {
if podMetrics, ok := pod.(*types.PodMetrics); ok && podMetrics.GetPod() != nil {
gotByAddress[podMetrics.GetPod().Address] = score
}
}
if diff := cmp.Diff(tt.wantScoresByAddress, gotByAddress); diff != "" {
t.Errorf("Unexpected output (-want +got): %v", diff)
}
})
}
}
Loading