[Issue #38] Make configuration values injectable rather than hardcoded #43

Merged
HugoNijhuis merged 1 commits from issue-38-injectable-config into main 2026-01-10 15:10:50 +00:00
13 changed files with 353 additions and 74 deletions

View File

@@ -44,5 +44,4 @@
// - Leader election ensures coordination continues despite node failures // - Leader election ensures coordination continues despite node failures
// - Actor migration allows rebalancing when cluster topology changes // - Actor migration allows rebalancing when cluster topology changes
// - Graceful shutdown with proper resource cleanup // - Graceful shutdown with proper resource cleanup
//
package cluster package cluster

125
cluster/config_test.go Normal file
View File

@@ -0,0 +1,125 @@
package cluster
import (
"testing"
)
func TestDefaultHashRingConfig(t *testing.T) {
config := DefaultHashRingConfig()
if config.VirtualNodes != DefaultVirtualNodes {
t.Errorf("expected VirtualNodes=%d, got %d", DefaultVirtualNodes, config.VirtualNodes)
}
}
func TestDefaultShardConfig(t *testing.T) {
config := DefaultShardConfig()
if config.ShardCount != DefaultNumShards {
t.Errorf("expected ShardCount=%d, got %d", DefaultNumShards, config.ShardCount)
}
if config.ReplicationFactor != 1 {
t.Errorf("expected ReplicationFactor=1, got %d", config.ReplicationFactor)
}
}
func TestNewConsistentHashRingWithConfig(t *testing.T) {
t.Run("custom virtual nodes", func(t *testing.T) {
config := HashRingConfig{VirtualNodes: 50}
ring := NewConsistentHashRingWithConfig(config)
ring.AddNode("test-node")
if len(ring.sortedHashes) != 50 {
t.Errorf("expected 50 virtual nodes, got %d", len(ring.sortedHashes))
}
if ring.GetVirtualNodes() != 50 {
t.Errorf("expected GetVirtualNodes()=50, got %d", ring.GetVirtualNodes())
}
})
t.Run("zero value uses default", func(t *testing.T) {
config := HashRingConfig{VirtualNodes: 0}
ring := NewConsistentHashRingWithConfig(config)
ring.AddNode("test-node")
if len(ring.sortedHashes) != DefaultVirtualNodes {
t.Errorf("expected %d virtual nodes, got %d", DefaultVirtualNodes, len(ring.sortedHashes))
}
})
t.Run("default constructor uses default config", func(t *testing.T) {
ring := NewConsistentHashRing()
ring.AddNode("test-node")
if len(ring.sortedHashes) != DefaultVirtualNodes {
t.Errorf("expected %d virtual nodes, got %d", DefaultVirtualNodes, len(ring.sortedHashes))
}
})
}
func TestNewShardManagerWithConfig(t *testing.T) {
t.Run("custom shard count", func(t *testing.T) {
config := ShardConfig{ShardCount: 256, ReplicationFactor: 2}
sm := NewShardManagerWithConfig(config)
if sm.GetShardCount() != 256 {
t.Errorf("expected shard count 256, got %d", sm.GetShardCount())
}
if sm.GetReplicationFactor() != 2 {
t.Errorf("expected replication factor 2, got %d", sm.GetReplicationFactor())
}
})
t.Run("zero values use defaults", func(t *testing.T) {
config := ShardConfig{ShardCount: 0, ReplicationFactor: 0}
sm := NewShardManagerWithConfig(config)
if sm.GetShardCount() != DefaultNumShards {
t.Errorf("expected shard count %d, got %d", DefaultNumShards, sm.GetShardCount())
}
if sm.GetReplicationFactor() != 1 {
t.Errorf("expected replication factor 1, got %d", sm.GetReplicationFactor())
}
})
t.Run("legacy constructor still works", func(t *testing.T) {
sm := NewShardManager(512, 3)
if sm.GetShardCount() != 512 {
t.Errorf("expected shard count 512, got %d", sm.GetShardCount())
}
if sm.GetReplicationFactor() != 3 {
t.Errorf("expected replication factor 3, got %d", sm.GetReplicationFactor())
}
})
}
func TestShardManagerGetShard_DifferentShardCounts(t *testing.T) {
testCases := []struct {
shardCount int
}{
{shardCount: 16},
{shardCount: 64},
{shardCount: 256},
{shardCount: 1024},
{shardCount: 4096},
}
for _, tc := range testCases {
t.Run("shardCount="+string(rune(tc.shardCount)), func(t *testing.T) {
sm := NewShardManagerWithConfig(ShardConfig{ShardCount: tc.shardCount})
// Verify all actor IDs map to valid shard range
for i := 0; i < 1000; i++ {
actorID := "actor-" + string(rune(i))
shard := sm.GetShard(actorID)
if shard < 0 || shard >= tc.shardCount {
t.Errorf("shard %d out of range [0, %d)", shard, tc.shardCount)
}
}
})
}
}

View File

@@ -200,11 +200,11 @@ func (dvm *DistributedVM) GetClusterInfo() map[string]interface{} {
nodes := dvm.cluster.GetNodes() nodes := dvm.cluster.GetNodes()
return map[string]interface{}{ return map[string]interface{}{
"nodeId": dvm.nodeID, "nodeId": dvm.nodeID,
"isLeader": dvm.cluster.IsLeader(), "isLeader": dvm.cluster.IsLeader(),
"leader": dvm.cluster.GetLeader(), "leader": dvm.cluster.GetLeader(),
"nodeCount": len(nodes), "nodeCount": len(nodes),
"nodes": nodes, "nodes": nodes,
} }
} }

View File

@@ -12,13 +12,24 @@ type ConsistentHashRing struct {
ring map[uint32]string // hash -> node ID ring map[uint32]string // hash -> node ID
sortedHashes []uint32 // sorted hash keys sortedHashes []uint32 // sorted hash keys
nodes map[string]bool // active nodes nodes map[string]bool // active nodes
virtualNodes int // number of virtual nodes per physical node
} }
// NewConsistentHashRing creates a new consistent hash ring // NewConsistentHashRing creates a new consistent hash ring with default configuration
func NewConsistentHashRing() *ConsistentHashRing { func NewConsistentHashRing() *ConsistentHashRing {
return NewConsistentHashRingWithConfig(DefaultHashRingConfig())
}
// NewConsistentHashRingWithConfig creates a new consistent hash ring with custom configuration
func NewConsistentHashRingWithConfig(config HashRingConfig) *ConsistentHashRing {
virtualNodes := config.VirtualNodes
if virtualNodes == 0 {
virtualNodes = DefaultVirtualNodes
}
return &ConsistentHashRing{ return &ConsistentHashRing{
ring: make(map[uint32]string), ring: make(map[uint32]string),
nodes: make(map[string]bool), nodes: make(map[string]bool),
virtualNodes: virtualNodes,
} }
} }
@@ -31,7 +42,7 @@ func (chr *ConsistentHashRing) AddNode(nodeID string) {
chr.nodes[nodeID] = true chr.nodes[nodeID] = true
// Add virtual nodes for better distribution // Add virtual nodes for better distribution
for i := 0; i < VirtualNodes; i++ { for i := 0; i < chr.virtualNodes; i++ {
virtualKey := fmt.Sprintf("%s:%d", nodeID, i) virtualKey := fmt.Sprintf("%s:%d", nodeID, i)
hash := chr.hash(virtualKey) hash := chr.hash(virtualKey)
chr.ring[hash] = nodeID chr.ring[hash] = nodeID
@@ -103,3 +114,8 @@ func (chr *ConsistentHashRing) GetNodes() []string {
func (chr *ConsistentHashRing) IsEmpty() bool { func (chr *ConsistentHashRing) IsEmpty() bool {
return len(chr.nodes) == 0 return len(chr.nodes) == 0
} }
// GetVirtualNodes returns the number of virtual nodes per physical node
func (chr *ConsistentHashRing) GetVirtualNodes() int {
return chr.virtualNodes
}

View File

@@ -42,7 +42,7 @@ func TestAddNode(t *testing.T) {
} }
// Verify virtual nodes were added // Verify virtual nodes were added
expectedVirtualNodes := VirtualNodes expectedVirtualNodes := DefaultVirtualNodes
if len(ring.sortedHashes) != expectedVirtualNodes { if len(ring.sortedHashes) != expectedVirtualNodes {
t.Errorf("expected %d virtual nodes, got %d", expectedVirtualNodes, len(ring.sortedHashes)) t.Errorf("expected %d virtual nodes, got %d", expectedVirtualNodes, len(ring.sortedHashes))
} }
@@ -86,7 +86,7 @@ func TestAddNode_MultipleNodes(t *testing.T) {
t.Errorf("expected 3 nodes, got %d", len(nodes)) t.Errorf("expected 3 nodes, got %d", len(nodes))
} }
expectedHashes := VirtualNodes * 3 expectedHashes := DefaultVirtualNodes * 3
if len(ring.sortedHashes) != expectedHashes { if len(ring.sortedHashes) != expectedHashes {
t.Errorf("expected %d virtual nodes, got %d", expectedHashes, len(ring.sortedHashes)) t.Errorf("expected %d virtual nodes, got %d", expectedHashes, len(ring.sortedHashes))
} }
@@ -118,7 +118,7 @@ func TestRemoveNode(t *testing.T) {
} }
// Verify virtual nodes were removed // Verify virtual nodes were removed
expectedHashes := VirtualNodes expectedHashes := DefaultVirtualNodes
if len(ring.sortedHashes) != expectedHashes { if len(ring.sortedHashes) != expectedHashes {
t.Errorf("expected %d virtual nodes, got %d", expectedHashes, len(ring.sortedHashes)) t.Errorf("expected %d virtual nodes, got %d", expectedHashes, len(ring.sortedHashes))
} }
@@ -321,7 +321,7 @@ func TestRingBehavior_ManyNodes(t *testing.T) {
} }
// Verify virtual nodes count // Verify virtual nodes count
expectedHashes := numNodes * VirtualNodes expectedHashes := numNodes * DefaultVirtualNodes
if len(ring.sortedHashes) != expectedHashes { if len(ring.sortedHashes) != expectedHashes {
t.Errorf("expected %d virtual nodes, got %d", expectedHashes, len(ring.sortedHashes)) t.Errorf("expected %d virtual nodes, got %d", expectedHashes, len(ring.sortedHashes))
} }
@@ -355,7 +355,7 @@ func TestRingBehavior_ManyNodes(t *testing.T) {
} }
} }
func TestVirtualNodes_ImproveDistribution(t *testing.T) { func TestDefaultVirtualNodes_ImproveDistribution(t *testing.T) {
// Test that virtual nodes actually improve distribution // Test that virtual nodes actually improve distribution
// by comparing with a theoretical single-hash-per-node scenario // by comparing with a theoretical single-hash-per-node scenario
@@ -386,7 +386,7 @@ func TestVirtualNodes_ImproveDistribution(t *testing.T) {
stdDev := math.Sqrt(sumSquaredDiff / float64(numNodes)) stdDev := math.Sqrt(sumSquaredDiff / float64(numNodes))
coefficientOfVariation := stdDev / expectedPerNode coefficientOfVariation := stdDev / expectedPerNode
// With VirtualNodes=150, we expect good distribution // With DefaultVirtualNodes=150, we expect good distribution
// Coefficient of variation should be low (< 15%) // Coefficient of variation should be low (< 15%)
if coefficientOfVariation > 0.15 { if coefficientOfVariation > 0.15 {
t.Errorf("distribution has high coefficient of variation: %.2f%% (expected < 15%%)", t.Errorf("distribution has high coefficient of variation: %.2f%% (expected < 15%%)",
@@ -394,8 +394,8 @@ func TestVirtualNodes_ImproveDistribution(t *testing.T) {
} }
// Verify that the actual number of virtual nodes matches expected // Verify that the actual number of virtual nodes matches expected
if len(ring.sortedHashes) != numNodes*VirtualNodes { if len(ring.sortedHashes) != numNodes*DefaultVirtualNodes {
t.Errorf("expected %d virtual node hashes, got %d", numNodes*VirtualNodes, len(ring.sortedHashes)) t.Errorf("expected %d virtual node hashes, got %d", numNodes*DefaultVirtualNodes, len(ring.sortedHashes))
} }
} }

View File

@@ -44,8 +44,8 @@ func NewLeaderElection(nodeID string, natsConn *nats.Conn, callbacks LeaderElect
Bucket: "aether-leader-election", Bucket: "aether-leader-election",
Description: "Aether cluster leader election coordination", Description: "Aether cluster leader election coordination",
TTL: LeaderLeaseTimeout * 2, // Auto-cleanup expired leases TTL: LeaderLeaseTimeout * 2, // Auto-cleanup expired leases
MaxBytes: 1024 * 1024, // 1MB max MaxBytes: 1024 * 1024, // 1MB max
Replicas: 1, // Single replica for simplicity Replicas: 1, // Single replica for simplicity
}) })
if err != nil { if err != nil {
// Try to get existing KV store // Try to get existing KV store

View File

@@ -20,17 +20,17 @@ type VMRegistry interface {
// ClusterManager coordinates distributed VM operations across the cluster // ClusterManager coordinates distributed VM operations across the cluster
type ClusterManager struct { type ClusterManager struct {
nodeID string nodeID string
nodes map[string]*NodeInfo nodes map[string]*NodeInfo
nodeUpdates chan NodeUpdate nodeUpdates chan NodeUpdate
shardMap *ShardMap shardMap *ShardMap
hashRing *ConsistentHashRing hashRing *ConsistentHashRing
election *LeaderElection election *LeaderElection
natsConn *nats.Conn natsConn *nats.Conn
ctx context.Context ctx context.Context
mutex sync.RWMutex mutex sync.RWMutex
logger *log.Logger logger *log.Logger
vmRegistry VMRegistry // Interface to access local VMs vmRegistry VMRegistry // Interface to access local VMs
} }
// NewClusterManager creates a cluster coordination manager // NewClusterManager creates a cluster coordination manager

View File

@@ -33,8 +33,26 @@ type ShardManager struct {
replication int replication int
} }
// NewShardManager creates a new shard manager // NewShardManager creates a new shard manager with default configuration
func NewShardManager(shardCount, replication int) *ShardManager { func NewShardManager(shardCount, replication int) *ShardManager {
return NewShardManagerWithConfig(ShardConfig{
ShardCount: shardCount,
ReplicationFactor: replication,
})
}
// NewShardManagerWithConfig creates a new shard manager with custom configuration
func NewShardManagerWithConfig(config ShardConfig) *ShardManager {
// Apply defaults for zero values
shardCount := config.ShardCount
if shardCount == 0 {
shardCount = DefaultNumShards
}
replication := config.ReplicationFactor
if replication == 0 {
replication = 1
}
return &ShardManager{ return &ShardManager{
shardCount: shardCount, shardCount: shardCount,
shardMap: &ShardMap{Shards: make(map[int][]string), Nodes: make(map[string]NodeInfo)}, shardMap: &ShardMap{Shards: make(map[int][]string), Nodes: make(map[string]NodeInfo)},
@@ -149,6 +167,15 @@ func (sm *ShardManager) GetActorsInShard(shardID int, nodeID string, vmRegistry
return actors return actors
} }
// GetShardCount returns the total number of shards
func (sm *ShardManager) GetShardCount() int {
return sm.shardCount
}
// GetReplicationFactor returns the replication factor
func (sm *ShardManager) GetReplicationFactor() int {
return sm.replication
}
// ConsistentHashPlacement implements PlacementStrategy using consistent hashing // ConsistentHashPlacement implements PlacementStrategy using consistent hashing
type ConsistentHashPlacement struct{} type ConsistentHashPlacement struct{}

View File

@@ -4,17 +4,47 @@ import (
"time" "time"
) )
// Default configuration values
const ( const (
// NumShards defines the total number of shards in the cluster // DefaultNumShards defines the default total number of shards in the cluster
NumShards = 1024 DefaultNumShards = 1024
// VirtualNodes defines the number of virtual nodes per physical node for consistent hashing // DefaultVirtualNodes defines the default number of virtual nodes per physical node
VirtualNodes = 150 DefaultVirtualNodes = 150
// Leadership election constants // Leadership election constants
LeaderLeaseTimeout = 10 * time.Second // How long a leader lease lasts LeaderLeaseTimeout = 10 * time.Second // How long a leader lease lasts
HeartbeatInterval = 3 * time.Second // How often leader sends heartbeats HeartbeatInterval = 3 * time.Second // How often leader sends heartbeats
ElectionTimeout = 2 * time.Second // How long to wait for election ElectionTimeout = 2 * time.Second // How long to wait for election
) )
// HashRingConfig holds configuration for the consistent hash ring
type HashRingConfig struct {
// VirtualNodes is the number of virtual nodes per physical node (default: 150)
VirtualNodes int
}
// DefaultHashRingConfig returns the default hash ring configuration
func DefaultHashRingConfig() HashRingConfig {
return HashRingConfig{
VirtualNodes: DefaultVirtualNodes,
}
}
// ShardConfig holds configuration for shard management
type ShardConfig struct {
// ShardCount is the total number of shards (default: 1024)
ShardCount int
// ReplicationFactor is the number of replicas per shard (default: 1)
ReplicationFactor int
}
// DefaultShardConfig returns the default shard configuration
func DefaultShardConfig() ShardConfig {
return ShardConfig{
ShardCount: DefaultNumShards,
ReplicationFactor: 1,
}
}
// NodeStatus represents the health status of a node // NodeStatus represents the health status of a node
type NodeStatus string type NodeStatus string
@@ -30,14 +60,14 @@ type NodeInfo struct {
Address string `json:"address"` Address string `json:"address"`
Port int `json:"port"` Port int `json:"port"`
Status NodeStatus `json:"status"` Status NodeStatus `json:"status"`
Capacity float64 `json:"capacity"` // Maximum load capacity Capacity float64 `json:"capacity"` // Maximum load capacity
Load float64 `json:"load"` // Current CPU/memory load Load float64 `json:"load"` // Current CPU/memory load
LastSeen time.Time `json:"lastSeen"` // Last heartbeat timestamp LastSeen time.Time `json:"lastSeen"` // Last heartbeat timestamp
Timestamp time.Time `json:"timestamp"` Timestamp time.Time `json:"timestamp"`
Metadata map[string]string `json:"metadata"` Metadata map[string]string `json:"metadata"`
IsLeader bool `json:"isLeader"` IsLeader bool `json:"isLeader"`
VMCount int `json:"vmCount"` // Number of VMs on this node VMCount int `json:"vmCount"` // Number of VMs on this node
ShardIDs []int `json:"shardIds"` // Shards assigned to this node ShardIDs []int `json:"shardIds"` // Shards assigned to this node
} }
// NodeUpdateType represents the type of node update // NodeUpdateType represents the type of node update
@@ -57,9 +87,9 @@ type NodeUpdate struct {
// ShardMap represents the distribution of shards across cluster nodes // ShardMap represents the distribution of shards across cluster nodes
type ShardMap struct { type ShardMap struct {
Version uint64 `json:"version"` // Incremented on each change Version uint64 `json:"version"` // Incremented on each change
Shards map[int][]string `json:"shards"` // shard ID -> [primary, replica1, replica2] Shards map[int][]string `json:"shards"` // shard ID -> [primary, replica1, replica2]
Nodes map[string]NodeInfo `json:"nodes"` // node ID -> node info Nodes map[string]NodeInfo `json:"nodes"` // node ID -> node info
UpdateTime time.Time `json:"updateTime"` UpdateTime time.Time `json:"updateTime"`
} }
@@ -74,23 +104,23 @@ type ClusterMessage struct {
// RebalanceRequest represents a request to rebalance shards // RebalanceRequest represents a request to rebalance shards
type RebalanceRequest struct { type RebalanceRequest struct {
RequestID string `json:"requestId"` RequestID string `json:"requestId"`
FromNode string `json:"fromNode"` FromNode string `json:"fromNode"`
ToNode string `json:"toNode"` ToNode string `json:"toNode"`
ShardIDs []int `json:"shardIds"` ShardIDs []int `json:"shardIds"`
Reason string `json:"reason"` Reason string `json:"reason"`
Migrations []ActorMigration `json:"migrations"` Migrations []ActorMigration `json:"migrations"`
} }
// ActorMigration represents the migration of an actor between nodes // ActorMigration represents the migration of an actor between nodes
type ActorMigration struct { type ActorMigration struct {
ActorID string `json:"actorId"` ActorID string `json:"actorId"`
FromNode string `json:"fromNode"` FromNode string `json:"fromNode"`
ToNode string `json:"toNode"` ToNode string `json:"toNode"`
ShardID int `json:"shardId"` ShardID int `json:"shardId"`
State map[string]interface{} `json:"state"` State map[string]interface{} `json:"state"`
Version int64 `json:"version"` Version int64 `json:"version"`
Status string `json:"status"` // "pending", "in_progress", "completed", "failed" Status string `json:"status"` // "pending", "in_progress", "completed", "failed"
} }
// LeaderElectionCallbacks defines callbacks for leadership changes // LeaderElectionCallbacks defines callbacks for leadership changes
@@ -107,4 +137,3 @@ type LeadershipLease struct {
ExpiresAt time.Time `json:"expiresAt"` ExpiresAt time.Time `json:"expiresAt"`
StartedAt time.Time `json:"startedAt"` StartedAt time.Time `json:"startedAt"`
} }

View File

@@ -13,11 +13,11 @@ import (
// NATSEventBus is an EventBus that broadcasts events across all cluster nodes using NATS // NATSEventBus is an EventBus that broadcasts events across all cluster nodes using NATS
type NATSEventBus struct { type NATSEventBus struct {
*EventBus // Embed base EventBus for local subscriptions *EventBus // Embed base EventBus for local subscriptions
nc *nats.Conn // NATS connection nc *nats.Conn // NATS connection
subscriptions []*nats.Subscription subscriptions []*nats.Subscription
namespaceSubscribers map[string]int // Track number of subscribers per namespace namespaceSubscribers map[string]int // Track number of subscribers per namespace
nodeID string // Unique ID for this node nodeID string // Unique ID for this node
mutex sync.Mutex mutex sync.Mutex
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc

46
store/config_test.go Normal file
View File

@@ -0,0 +1,46 @@
package store
import (
"testing"
"time"
)
func TestDefaultJetStreamConfig(t *testing.T) {
config := DefaultJetStreamConfig()
if config.StreamRetention != DefaultStreamRetention {
t.Errorf("expected StreamRetention=%v, got %v", DefaultStreamRetention, config.StreamRetention)
}
if config.ReplicaCount != DefaultReplicaCount {
t.Errorf("expected ReplicaCount=%d, got %d", DefaultReplicaCount, config.ReplicaCount)
}
}
func TestJetStreamConfigDefaults(t *testing.T) {
t.Run("default stream retention is 1 year", func(t *testing.T) {
expected := 365 * 24 * time.Hour
if DefaultStreamRetention != expected {
t.Errorf("expected DefaultStreamRetention=%v, got %v", expected, DefaultStreamRetention)
}
})
t.Run("default replica count is 1", func(t *testing.T) {
if DefaultReplicaCount != 1 {
t.Errorf("expected DefaultReplicaCount=1, got %d", DefaultReplicaCount)
}
})
}
func TestJetStreamConfigCustomValues(t *testing.T) {
config := JetStreamConfig{
StreamRetention: 30 * 24 * time.Hour, // 30 days
ReplicaCount: 3,
}
if config.StreamRetention != 30*24*time.Hour {
t.Errorf("expected StreamRetention=30 days, got %v", config.StreamRetention)
}
if config.ReplicaCount != 3 {
t.Errorf("expected ReplicaCount=3, got %d", config.ReplicaCount)
}
}

View File

@@ -11,29 +11,65 @@ import (
"github.com/nats-io/nats.go" "github.com/nats-io/nats.go"
) )
// Default configuration values for JetStream event store
const (
DefaultStreamRetention = 365 * 24 * time.Hour // 1 year
DefaultReplicaCount = 1
)
// JetStreamConfig holds configuration options for JetStreamEventStore
type JetStreamConfig struct {
// StreamRetention is how long to keep events (default: 1 year)
StreamRetention time.Duration
// ReplicaCount is the number of replicas for high availability (default: 1)
ReplicaCount int
}
// DefaultJetStreamConfig returns the default configuration
func DefaultJetStreamConfig() JetStreamConfig {
return JetStreamConfig{
StreamRetention: DefaultStreamRetention,
ReplicaCount: DefaultReplicaCount,
}
}
// JetStreamEventStore implements EventStore using NATS JetStream for persistence // JetStreamEventStore implements EventStore using NATS JetStream for persistence
type JetStreamEventStore struct { type JetStreamEventStore struct {
js nats.JetStreamContext js nats.JetStreamContext
streamName string streamName string
mu sync.Mutex // Protects version checks during SaveEvent config JetStreamConfig
versions map[string]int64 // actorID -> latest version cache mu sync.Mutex // Protects version checks during SaveEvent
versions map[string]int64 // actorID -> latest version cache
} }
// NewJetStreamEventStore creates a new JetStream-based event store // NewJetStreamEventStore creates a new JetStream-based event store with default configuration
func NewJetStreamEventStore(natsConn *nats.Conn, streamName string) (*JetStreamEventStore, error) { func NewJetStreamEventStore(natsConn *nats.Conn, streamName string) (*JetStreamEventStore, error) {
return NewJetStreamEventStoreWithConfig(natsConn, streamName, DefaultJetStreamConfig())
}
// NewJetStreamEventStoreWithConfig creates a new JetStream-based event store with custom configuration
func NewJetStreamEventStoreWithConfig(natsConn *nats.Conn, streamName string, config JetStreamConfig) (*JetStreamEventStore, error) {
js, err := natsConn.JetStream() js, err := natsConn.JetStream()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get JetStream context: %w", err) return nil, fmt.Errorf("failed to get JetStream context: %w", err)
} }
// Apply defaults for zero values
if config.StreamRetention == 0 {
config.StreamRetention = DefaultStreamRetention
}
if config.ReplicaCount == 0 {
config.ReplicaCount = DefaultReplicaCount
}
// Create or update the stream // Create or update the stream
stream := &nats.StreamConfig{ stream := &nats.StreamConfig{
Name: streamName, Name: streamName,
Subjects: []string{fmt.Sprintf("%s.events.>", streamName), fmt.Sprintf("%s.snapshots.>", streamName)}, Subjects: []string{fmt.Sprintf("%s.events.>", streamName), fmt.Sprintf("%s.snapshots.>", streamName)},
Storage: nats.FileStorage, Storage: nats.FileStorage,
Retention: nats.LimitsPolicy, Retention: nats.LimitsPolicy,
MaxAge: 365 * 24 * time.Hour, // Keep events for 1 year MaxAge: config.StreamRetention,
Replicas: 1, // Can be increased for HA Replicas: config.ReplicaCount,
} }
_, err = js.AddStream(stream) _, err = js.AddStream(stream)
@@ -44,6 +80,7 @@ func NewJetStreamEventStore(natsConn *nats.Conn, streamName string) (*JetStreamE
return &JetStreamEventStore{ return &JetStreamEventStore{
js: js, js: js,
streamName: streamName, streamName: streamName,
config: config,
versions: make(map[string]int64), versions: make(map[string]int64),
}, nil }, nil
} }