Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 106 additions & 3 deletions gateway/domain_circuit_breaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,40 @@ import (
"github.com/pokt-network/poktroll/pkg/polylog"
"github.com/redis/go-redis/v9"

"github.com/pokt-network/path/metrics"
"github.com/pokt-network/path/protocol"
)

const defaultMaxTTL = 30 * time.Minute

// classifyCircuitBreakReason maps the free-text reason string passed to
// MarkBroken into a bounded category, suitable for use as a Prometheus label.
// The raw reason contains response snippets, error messages, and status codes
// (high cardinality) — for metrics we only need the prefix bucket.
//
// Keep in sync with the metrics.CircuitBreakReason* constants and with the
// reason strings produced by MarkBroken call sites in
// gateway/http_request_context_handle_request.go.
func classifyCircuitBreakReason(reason string) string {
	// Reasons are typically formatted as "<category>: <details>" or
	// "<category>_<detail>: ...". The table is scanned in order, so prefixes
	// that contain a shorter prefix (parallel_retry vs retry, batch_heuristic
	// vs heuristic) must be listed before the shorter one they would shadow.
	prefixCategories := []struct {
		prefix   string
		category string
	}{
		{"parallel_retry", metrics.CircuitBreakReasonParallelRetry},
		{"batch_transport", metrics.CircuitBreakReasonBatchTransport},
		{"batch_heuristic", metrics.CircuitBreakReasonBatchHeuristic},
		{"retry", metrics.CircuitBreakReasonRetry},
		{"heuristic", metrics.CircuitBreakReasonHeuristic},
	}
	for _, pc := range prefixCategories {
		if strings.HasPrefix(reason, pc.prefix) {
			return pc.category
		}
	}
	return metrics.CircuitBreakReasonUnknown
}

// DomainCircuitBreaker tracks broken domains across pods via Redis.
// When any pod discovers a domain is returning errors, it marks it broken
// so all pods skip that domain on initial attempts for the TTL window.
Expand Down Expand Up @@ -112,6 +141,17 @@ func (cb *DomainCircuitBreaker) MarkBroken(ctx context.Context, serviceID, domai
entry.domains[domain] = brokenDomainState{expiry: expiry, hitCount: hitCount, reason: reason}
cb.mu.Unlock()

// Surface state via Prometheus gauge so dashboards can show "currently broken
// domains" without needing Redis access. Idempotent — re-setting to 1 is fine.
metrics.SetCircuitBreakerState(serviceID, domain, true)

// Record a "broken" event with reason category so dashboards can show rate
// of breaks decomposed by cause (retry / batch_transport / batch_heuristic /
// parallel_retry / heuristic / unknown). Counter increments on every call
// — including hit-count escalations — which is the right semantic: each
// MarkBroken is a discrete trigger event.
metrics.RecordCircuitBreakerEvent(serviceID, domain, classifyCircuitBreakReason(reason), metrics.CircuitBreakerEventBroken)

// Always log circuit break events at error level for production visibility.
// This is critical for diagnosing why domains get locked out.
cb.logger.Error().
Expand Down Expand Up @@ -172,17 +212,24 @@ func filterExpiredDomains(domains map[string]brokenDomainState) map[string]bool
// Used when Redis is not available (local-only mode).
func (cb *DomainCircuitBreaker) refreshLocal(serviceID string) map[string]bool {
cb.mu.Lock()
defer cb.mu.Unlock()

entry, ok := cb.cache[serviceID]
if !ok {
cb.mu.Unlock()
return nil
}

// Remove expired entries
// Remove expired entries; collect their names + reasons so we can drop the
// gauge to 0 and record a "recovered" event outside the lock.
now := time.Now()
type expiredEntry struct {
domain string
reason string
}
var expired []expiredEntry
for domain, state := range entry.domains {
if !state.expiry.After(now) {
expired = append(expired, expiredEntry{domain: domain, reason: state.reason})
delete(entry.domains, domain)
}
}
Expand All @@ -192,6 +239,12 @@ func (cb *DomainCircuitBreaker) refreshLocal(serviceID string) map[string]bool {
for d := range entry.domains {
result[d] = true
}
cb.mu.Unlock()

for _, e := range expired {
metrics.SetCircuitBreakerState(serviceID, e.domain, false)
metrics.RecordCircuitBreakerEvent(serviceID, e.domain, classifyCircuitBreakReason(e.reason), metrics.CircuitBreakerEventRecovered)
}
return result
}

Expand Down Expand Up @@ -264,15 +317,24 @@ func (cb *DomainCircuitBreaker) refreshFromRedis(ctx context.Context, serviceID
cb.redisClient.HDel(ctx, key, expiredFields...)
}

// Merge with local cache (local entries might be newer than Redis)
// Merge with local cache (local entries might be newer than Redis).
// Capture pre-merge cache contents so we can drop the gauge to 0 for any
// domain that was previously broken but is no longer present after merge.
cb.mu.Lock()
type expiredLocal struct {
domain string
reason string
}
var previouslyBroken []expiredLocal
existingEntry, ok := cb.cache[serviceID]
if ok {
for domain, state := range existingEntry.domains {
if state.expiry.After(now) {
if redisState, exists := domains[domain]; !exists || state.expiry.After(redisState.expiry) {
domains[domain] = state
}
} else {
previouslyBroken = append(previouslyBroken, expiredLocal{domain: domain, reason: state.reason})
}
}
}
Expand All @@ -282,6 +344,29 @@ func (cb *DomainCircuitBreaker) refreshFromRedis(ctx context.Context, serviceID
}
cb.mu.Unlock()

// Drop gauge for previously-broken-but-now-expired domains, plus the explicit
// expiredFields list we already collected from Redis. Done outside the lock.
// We don't have a reason for expiredFields entries (parsed Redis value gave us
// one, but for compactness we don't carry it through here) — they get the
// "unknown" reason bucket on recovery, which is acceptable.
for _, d := range expiredFields {
metrics.SetCircuitBreakerState(serviceID, d, false)
metrics.RecordCircuitBreakerEvent(serviceID, d, metrics.CircuitBreakReasonUnknown, metrics.CircuitBreakerEventRecovered)
}
for _, e := range previouslyBroken {
metrics.SetCircuitBreakerState(serviceID, e.domain, false)
metrics.RecordCircuitBreakerEvent(serviceID, e.domain, classifyCircuitBreakReason(e.reason), metrics.CircuitBreakerEventRecovered)
}
// Re-assert gauge=1 for currently-broken domains. Idempotent and ensures the
// metric is correct on a fresh pod that lazily picks up Redis state without
// going through MarkBroken locally. Note: we deliberately do NOT record a
// "broken" event here — these aren't new transitions, just a metric resync
// for a state that was already counted at MarkBroken time on whichever pod
// originally fired it.
for d := range domains {
metrics.SetCircuitBreakerState(serviceID, d, true)
}

// Build result
result := make(map[string]bool, len(domains))
for d := range domains {
Expand All @@ -297,8 +382,17 @@ func (cb *DomainCircuitBreaker) ClearService(ctx context.Context, serviceID stri
cb.mu.Lock()
entry, ok := cb.cache[serviceID]
count := 0
type clearedEntry struct {
domain string
reason string
}
var cleared []clearedEntry
if ok {
count = len(entry.domains)
cleared = make([]clearedEntry, 0, count)
for d, state := range entry.domains {
cleared = append(cleared, clearedEntry{domain: d, reason: state.reason})
}
delete(cb.cache, serviceID)
}
cb.mu.Unlock()
Expand All @@ -307,6 +401,15 @@ func (cb *DomainCircuitBreaker) ClearService(ctx context.Context, serviceID stri
cb.redisClient.Del(ctx, cb.keyPrefix+serviceID)
}

// Drop the gauge AND record a recovered event for every cleared domain. Each
// admin-clear represents an explicit "operator forced these domains back to
// healthy" action, distinct from natural TTL-driven recovery — the recovery
// counter still captures it because operationally it's the same transition.
for _, c := range cleared {
metrics.SetCircuitBreakerState(serviceID, c.domain, false)
metrics.RecordCircuitBreakerEvent(serviceID, c.domain, classifyCircuitBreakReason(c.reason), metrics.CircuitBreakerEventRecovered)
}

cb.logger.Info().
Str("service_id", serviceID).
Int("cleared_domains", count).
Expand Down
121 changes: 121 additions & 0 deletions gateway/domain_circuit_breaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,22 @@ import (

"github.com/pokt-network/poktroll/pkg/polylog"
"github.com/pokt-network/poktroll/pkg/polylog/polyzero"
"github.com/prometheus/client_golang/prometheus/testutil"

"github.com/pokt-network/path/metrics"
"github.com/pokt-network/path/protocol"
)

// testCircuitBreakerLogger builds a polylog logger writing to stderr so that
// circuit-breaker log output is visible when running `go test -v`.
func testCircuitBreakerLogger() polylog.Logger {
	stderrOutput := polyzero.WithOutput(os.Stderr)
	return polyzero.NewLogger(stderrOutput)
}

// readCircuitBreakerGauge returns the current value of the per-(serviceID, domain)
// circuit-breaker state gauge. Used by metric-transition tests.
func readCircuitBreakerGauge(serviceID, domain string) float64 {
return testutil.ToFloat64(metrics.DomainCircuitBreakerState.WithLabelValues(serviceID, domain))
}

func TestDomainCircuitBreaker_MarkAndGet(t *testing.T) {
cb := NewDomainCircuitBreaker(nil, testCircuitBreakerLogger())
ctx := context.Background()
Expand Down Expand Up @@ -427,3 +435,116 @@ func TestParseRedisValue_InvalidFormat(t *testing.T) {
t.Fatal("expected parse to fail for empty string")
}
}

// TestClassifyCircuitBreakReason verifies that the free-text reason strings
// passed to MarkBroken get bucketed into the correct bounded category for use
// as a Prometheus label. Order-sensitive prefix matches must be stable —
// "parallel_retry" must NOT match the "retry" branch.
func TestClassifyCircuitBreakReason(t *testing.T) {
	testCases := []struct {
		input string
		want  string
	}{
		{"retry: heuristic_html | status=502 | response=<html>...", metrics.CircuitBreakReasonRetry},
		{"batch_transport_error: connection refused", metrics.CircuitBreakReasonBatchTransport},
		{"batch_heuristic: empty_array | status=200 | response=[]", metrics.CircuitBreakReasonBatchHeuristic},
		{"parallel_retry: timeout after 5s", metrics.CircuitBreakReasonParallelRetry},
		{"heuristic: structured error response", metrics.CircuitBreakReasonHeuristic},
		{"completely unknown reason format", metrics.CircuitBreakReasonUnknown},
		{"", metrics.CircuitBreakReasonUnknown},
	}
	for _, tc := range testCases {
		if got := classifyCircuitBreakReason(tc.input); got != tc.want {
			t.Errorf("classifyCircuitBreakReason(%q) = %q, want %q", tc.input, got, tc.want)
		}
	}
}

// TestCircuitBreakerEventsCounter verifies that broken/recovered transitions
// increment the events counter with the correct reason_category bucket.
func TestCircuitBreakerEventsCounter(t *testing.T) {
	cb := NewDomainCircuitBreaker(nil, testCircuitBreakerLogger())
	cb.defaultTTL = 50 * time.Millisecond
	ctx := context.Background()
	const serviceID = "events-test-svc"
	const domain = "events-test.example.com"

	// All reads target the (service, domain, retry, <event>) label combination.
	// We compare deltas against pre-captured values rather than absolutes so
	// the test stays hermetic against other tests' increments.
	eventCount := func(event string) float64 {
		return testutil.ToFloat64(metrics.DomainCircuitBreakerEventsTotal.WithLabelValues(
			serviceID, domain, metrics.CircuitBreakReasonRetry, event,
		))
	}
	preBroken := eventCount(metrics.CircuitBreakerEventBroken)
	preRecovered := eventCount(metrics.CircuitBreakerEventRecovered)

	// MarkBroken with a "retry: ..." reason should record one broken event with
	// reason_category="retry"
	cb.MarkBroken(ctx, serviceID, domain, "retry: heuristic_html | status=502 | response=oops")
	postBroken := eventCount(metrics.CircuitBreakerEventBroken)
	if postBroken-preBroken != 1 {
		t.Fatalf("expected broken counter to increment by 1, got delta=%v", postBroken-preBroken)
	}

	// ClearService should record one recovered event with the same reason_category
	cb.ClearService(ctx, serviceID)
	postRecovered := eventCount(metrics.CircuitBreakerEventRecovered)
	if postRecovered-preRecovered != 1 {
		t.Fatalf("expected recovered counter to increment by 1, got delta=%v", postRecovered-preRecovered)
	}
}

// TestDomainCircuitBreaker_MetricGaugeTransitions verifies that the circuit-breaker
// state gauge moves between 0 and 1 correctly: set to 1 on MarkBroken, dropped to 0
// on ClearService, and dropped to 0 by refreshLocal once a TTL expires. We test
// directly against the Prometheus client metric rather than scraping /metrics so the
// test stays hermetic.
func TestDomainCircuitBreaker_MetricGaugeTransitions(t *testing.T) {
	cb := NewDomainCircuitBreaker(nil, testCircuitBreakerLogger())
	cb.defaultTTL = 50 * time.Millisecond
	ctx := context.Background()
	const serviceID = "metric-test-svc"
	const domain = "metric-test-domain.example.com"

	// Initial state: gauge should be 0 (or absent — the Set on first MarkBroken creates it).
	got := readCircuitBreakerGauge(serviceID, domain)
	if got != 0 {
		t.Fatalf("expected gauge=0 initially, got %v", got)
	}

	// MarkBroken → gauge should flip to 1
	cb.MarkBroken(ctx, serviceID, domain, "test_reason")
	got = readCircuitBreakerGauge(serviceID, domain)
	if got != 1 {
		t.Fatalf("expected gauge=1 after MarkBroken, got %v", got)
	}

	// ClearService → gauge drops to 0 for cleared domains
	cb.ClearService(ctx, serviceID)
	got = readCircuitBreakerGauge(serviceID, domain)
	if got != 0 {
		t.Fatalf("expected gauge=0 after ClearService, got %v", got)
	}

	// MarkBroken again, then wait for TTL expiry, then refresh — gauge should drop to 0
	cb.MarkBroken(ctx, serviceID, domain, "test_reason_2")
	got = readCircuitBreakerGauge(serviceID, domain)
	if got != 1 {
		t.Fatalf("expected gauge=1 after second MarkBroken, got %v", got)
	}

	time.Sleep(60 * time.Millisecond)
	// Force the cache entry to look stale so GetBrokenDomains takes the refresh path.
	// In production, the same effect happens automatically once cacheTTL elapses
	// (default 5s) after the entry was last refreshed.
	cb.mu.Lock()
	if entry, ok := cb.cache[serviceID]; ok {
		entry.refreshAt = time.Now().Add(-time.Second)
	}
	cb.mu.Unlock()
	cb.GetBrokenDomains(ctx, serviceID)
	got = readCircuitBreakerGauge(serviceID, domain)
	if got != 0 {
		t.Fatalf("expected gauge=0 after TTL expiry + refresh, got %v", got)
	}
}
1 change: 1 addition & 0 deletions gateway/http_request_context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ func TestIsDeceptiveResponsePattern(t *testing.T) {
{"jsonrpc_invalid_empty_array", true},
{"jsonrpc_empty_object_result", true},
{"rest_protocol_mismatch", true},
{"rest_protocol_mismatch_error", false},
{"cometbft_invalid_empty_array", true},
{"jsonrpc_fabricated_parse_error", true},
{"jsonrpc_wrapped_service_error", true},
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ require (
github.com/klauspost/cpuid/v2 v2.2.10 // indirect
github.com/kr/pretty v0.3.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/kylelemons/godebug v1.1.0 // indirect
github.com/lib/pq v1.10.9 // indirect
github.com/linxGnu/grocksdb v1.8.14 // indirect
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect
Expand Down
31 changes: 31 additions & 0 deletions metrics/leaderboard.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,15 @@ type SupplierScoreEntry struct {
Score float64
}

// CooldownCountEntry represents per-domain count of endpoints currently in strike
// cooldown for a given service / rpc_type. One entry corresponds to one
// (Domain, RPCType, ServiceID) label combination on the cooldown gauge.
type CooldownCountEntry struct {
	Domain string // endpoint domain the cooldown count applies to
	RPCType string // RPC type dimension (used as the rpc_type metric label)
	ServiceID string // service the cooldown'd endpoints belong to
	Count int // number of endpoints currently in strike cooldown
}

// LeaderboardDataProvider is an interface for getting endpoint distribution data
type LeaderboardDataProvider interface {
// GetEndpointLeaderboardData returns all endpoint entries grouped by the required dimensions
Expand All @@ -49,6 +58,10 @@ type LeaderboardDataProvider interface {
// GetSupplierScoreData returns per-(supplier, service_id) reputation scores.
// Optional: implementations may return nil if per-supplier scoring is not supported.
GetSupplierScoreData(ctx context.Context) ([]SupplierScoreEntry, error)
// GetCooldownCountData returns per-(domain, service_id, rpc_type) counts of
// endpoints currently in strike cooldown. Optional: implementations may return
// nil if cooldown tracking is not supported.
GetCooldownCountData(ctx context.Context) ([]CooldownCountEntry, error)
}

// LeaderboardPublisher publishes endpoint leaderboard metrics every 10 seconds
Expand Down Expand Up @@ -183,6 +196,24 @@ func (lp *LeaderboardPublisher) publishLeaderboard(ctx context.Context) {
}
lp.logger.Debug().Int("entries", len(supplierScores)).Msg("Published supplier scores")
}

// Publish per-domain endpoint cooldown counts. Reset between snapshots so a
// domain that drops to zero cooldown'd endpoints actually shows zero (instead
// of sticking at its last value via Prometheus' 5-min staleness window).
cooldownCounts, err := lp.provider.GetCooldownCountData(ctx)
if err != nil {
lp.logger.Warn().Err(err).Msg("Failed to get cooldown count data")
return
}

EndpointsInCooldown.Reset()

if len(cooldownCounts) > 0 {
for _, entry := range cooldownCounts {
EndpointsInCooldown.WithLabelValues(entry.Domain, entry.RPCType, entry.ServiceID).Set(float64(entry.Count))
}
lp.logger.Debug().Int("entries", len(cooldownCounts)).Msg("Published cooldown counts")
}
}

// PublishOnce can be called to manually trigger a leaderboard publish (for testing)
Expand Down
Loading
Loading