From c757bb76f32da861e52cdaac3c8b73f330c0e4a0 Mon Sep 17 00:00:00 2001 From: Hugo Nijhuis Date: Sat, 10 Jan 2026 15:33:56 +0100 Subject: [PATCH] Make configuration values injectable rather than hardcoded Add config structs with sensible defaults for tunable parameters: - JetStreamConfig for stream retention (1 year) and replica count (1) - HashRingConfig for virtual nodes per physical node (150) - ShardConfig for shard count (1024) and replication factor (1) Each component gets a new WithConfig constructor that accepts custom configuration, while the original constructors continue to work with defaults. Zero values in configs fall back to defaults for backward compatibility. Closes #38 Co-Authored-By: Claude Opus 4.5 --- cluster/cluster.go | 3 +- cluster/config_test.go | 125 +++++++++++++++++++++++++++++++++++++++ cluster/discovery.go | 2 +- cluster/distributed.go | 12 ++-- cluster/hashring.go | 26 ++++++-- cluster/hashring_test.go | 16 ++--- cluster/leader.go | 6 +- cluster/manager.go | 24 ++++---- cluster/shard.go | 31 +++++++++- cluster/types.go | 81 +++++++++++++++++-------- nats_eventbus.go | 8 +-- store/config_test.go | 46 ++++++++++++++ store/jetstream.go | 47 +++++++++++++-- 13 files changed, 353 insertions(+), 74 deletions(-) create mode 100644 cluster/config_test.go create mode 100644 store/config_test.go diff --git a/cluster/cluster.go b/cluster/cluster.go index 88a30c5..e0597f6 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -44,5 +44,4 @@ // - Leader election ensures coordination continues despite node failures // - Actor migration allows rebalancing when cluster topology changes // - Graceful shutdown with proper resource cleanup -// -package cluster \ No newline at end of file +package cluster diff --git a/cluster/config_test.go b/cluster/config_test.go new file mode 100644 index 0000000..7fbb35e --- /dev/null +++ b/cluster/config_test.go @@ -0,0 +1,125 @@ +package cluster + +import ( + "testing" +) + +func TestDefaultHashRingConfig(t *testing.T) { + config := DefaultHashRingConfig() + + if config.VirtualNodes != DefaultVirtualNodes { + t.Errorf("expected VirtualNodes=%d, got %d", DefaultVirtualNodes, config.VirtualNodes) + } +} + +func TestDefaultShardConfig(t *testing.T) { + config := DefaultShardConfig() + + if config.ShardCount != DefaultNumShards { + t.Errorf("expected ShardCount=%d, got %d", DefaultNumShards, config.ShardCount) + } + if config.ReplicationFactor != 1 { + t.Errorf("expected ReplicationFactor=1, got %d", config.ReplicationFactor) + } +} + +func TestNewConsistentHashRingWithConfig(t *testing.T) { + t.Run("custom virtual nodes", func(t *testing.T) { + config := HashRingConfig{VirtualNodes: 50} + ring := NewConsistentHashRingWithConfig(config) + + ring.AddNode("test-node") + + if len(ring.sortedHashes) != 50 { + t.Errorf("expected 50 virtual nodes, got %d", len(ring.sortedHashes)) + } + if ring.GetVirtualNodes() != 50 { + t.Errorf("expected GetVirtualNodes()=50, got %d", ring.GetVirtualNodes()) + } + }) + + t.Run("zero value uses default", func(t *testing.T) { + config := HashRingConfig{VirtualNodes: 0} + ring := NewConsistentHashRingWithConfig(config) + + ring.AddNode("test-node") + + if len(ring.sortedHashes) != DefaultVirtualNodes { + t.Errorf("expected %d virtual nodes, got %d", DefaultVirtualNodes, len(ring.sortedHashes)) + } + }) + + t.Run("default constructor uses default config", func(t *testing.T) { + ring := NewConsistentHashRing() + + ring.AddNode("test-node") + + if len(ring.sortedHashes) != DefaultVirtualNodes { + t.Errorf("expected %d virtual nodes, got %d", DefaultVirtualNodes, len(ring.sortedHashes)) + } + }) +} + +func TestNewShardManagerWithConfig(t *testing.T) { + t.Run("custom shard count", func(t *testing.T) { + config := ShardConfig{ShardCount: 256, ReplicationFactor: 2} + sm := NewShardManagerWithConfig(config) + + if sm.GetShardCount() != 256 { + t.Errorf("expected shard count 256, got %d", sm.GetShardCount()) + } + if sm.GetReplicationFactor() != 2 { + t.Errorf("expected replication factor 2, got %d", sm.GetReplicationFactor()) + } + }) + + t.Run("zero values use defaults", func(t *testing.T) { + config := ShardConfig{ShardCount: 0, ReplicationFactor: 0} + sm := NewShardManagerWithConfig(config) + + if sm.GetShardCount() != DefaultNumShards { + t.Errorf("expected shard count %d, got %d", DefaultNumShards, sm.GetShardCount()) + } + if sm.GetReplicationFactor() != 1 { + t.Errorf("expected replication factor 1, got %d", sm.GetReplicationFactor()) + } + }) + + t.Run("legacy constructor still works", func(t *testing.T) { + sm := NewShardManager(512, 3) + + if sm.GetShardCount() != 512 { + t.Errorf("expected shard count 512, got %d", sm.GetShardCount()) + } + if sm.GetReplicationFactor() != 3 { + t.Errorf("expected replication factor 3, got %d", sm.GetReplicationFactor()) + } + }) +} + +func TestShardManagerGetShard_DifferentShardCounts(t *testing.T) { + testCases := []struct { + shardCount int + }{ + {shardCount: 16}, + {shardCount: 64}, + {shardCount: 256}, + {shardCount: 1024}, + {shardCount: 4096}, + } + + for _, tc := range testCases { + t.Run("shardCount="+string(rune(tc.shardCount)), func(t *testing.T) { + sm := NewShardManagerWithConfig(ShardConfig{ShardCount: tc.shardCount}) + + // Verify all actor IDs map to valid shard range + for i := 0; i < 1000; i++ { + actorID := "actor-" + string(rune(i)) + shard := sm.GetShard(actorID) + if shard < 0 || shard >= tc.shardCount { + t.Errorf("shard %d out of range [0, %d)", shard, tc.shardCount) + } + } + }) + } +} diff --git a/cluster/discovery.go b/cluster/discovery.go index 6c2e13f..270f8d2 100644 --- a/cluster/discovery.go +++ b/cluster/discovery.go @@ -115,4 +115,4 @@ func (nd *NodeDiscovery) announceNode(updateType NodeUpdateType) { // Stop gracefully stops the node discovery service func (nd *NodeDiscovery) Stop() { nd.announceNode(NodeLeft) -} \ No newline at end of file +} diff --git a/cluster/distributed.go b/cluster/distributed.go index 365522d..b38e02c 100644 --- a/cluster/distributed.go +++ b/cluster/distributed.go @@ -200,11 +200,11 @@ func (dvm *DistributedVM) GetClusterInfo() map[string]interface{} { nodes := dvm.cluster.GetNodes() return map[string]interface{}{ - "nodeId": dvm.nodeID, - "isLeader": dvm.cluster.IsLeader(), - "leader": dvm.cluster.GetLeader(), - "nodeCount": len(nodes), - "nodes": nodes, + "nodeId": dvm.nodeID, + "isLeader": dvm.cluster.IsLeader(), + "leader": dvm.cluster.GetLeader(), + "nodeCount": len(nodes), + "nodes": nodes, } } @@ -218,4 +218,4 @@ func (dvr *DistributedVMRegistry) GetActiveVMs() map[string]interface{} { // GetShard returns the shard number for the given actor ID func (dvr *DistributedVMRegistry) GetShard(actorID string) int { return dvr.sharding.GetShard(actorID) -} \ No newline at end of file +} diff --git a/cluster/hashring.go b/cluster/hashring.go index f1c03c3..0d5f99c 100644 --- a/cluster/hashring.go +++ b/cluster/hashring.go @@ -12,13 +12,24 @@ type ConsistentHashRing struct { ring map[uint32]string // hash -> node ID sortedHashes []uint32 // sorted hash keys nodes map[string]bool // active nodes + virtualNodes int // number of virtual nodes per physical node } -// NewConsistentHashRing creates a new consistent hash ring +// NewConsistentHashRing creates a new consistent hash ring with default configuration func NewConsistentHashRing() *ConsistentHashRing { + return NewConsistentHashRingWithConfig(DefaultHashRingConfig()) +} + +// NewConsistentHashRingWithConfig creates a new consistent hash ring with custom configuration +func NewConsistentHashRingWithConfig(config HashRingConfig) *ConsistentHashRing { + virtualNodes := config.VirtualNodes + if virtualNodes == 0 { + virtualNodes = DefaultVirtualNodes + } return &ConsistentHashRing{ - ring: make(map[uint32]string), - nodes: make(map[string]bool), + ring: make(map[uint32]string), + nodes: make(map[string]bool), + virtualNodes: virtualNodes, } } @@ -31,7 +42,7 @@ func (chr *ConsistentHashRing) AddNode(nodeID string) { chr.nodes[nodeID] = true // Add virtual nodes for better distribution - for i := 0; i < VirtualNodes; i++ { + for i := 0; i < chr.virtualNodes; i++ { virtualKey := fmt.Sprintf("%s:%d", nodeID, i) hash := chr.hash(virtualKey) chr.ring[hash] = nodeID @@ -102,4 +113,9 @@ func (chr *ConsistentHashRing) GetNodes() []string { // IsEmpty returns true if the ring has no nodes func (chr *ConsistentHashRing) IsEmpty() bool { return len(chr.nodes) == 0 -} \ No newline at end of file +} + +// GetVirtualNodes returns the number of virtual nodes per physical node +func (chr *ConsistentHashRing) GetVirtualNodes() int { + return chr.virtualNodes +} diff --git a/cluster/hashring_test.go b/cluster/hashring_test.go index 18eadc6..26298f8 100644 --- a/cluster/hashring_test.go +++ b/cluster/hashring_test.go @@ -42,7 +42,7 @@ func TestAddNode(t *testing.T) { } // Verify virtual nodes were added - expectedVirtualNodes := VirtualNodes + expectedVirtualNodes := DefaultVirtualNodes if len(ring.sortedHashes) != expectedVirtualNodes { t.Errorf("expected %d virtual nodes, got %d", expectedVirtualNodes, len(ring.sortedHashes)) } @@ -86,7 +86,7 @@ func TestAddNode_MultipleNodes(t *testing.T) { t.Errorf("expected 3 nodes, got %d", len(nodes)) } - expectedHashes := VirtualNodes * 3 + expectedHashes := DefaultVirtualNodes * 3 if len(ring.sortedHashes) != expectedHashes { t.Errorf("expected %d virtual nodes, got %d", expectedHashes, len(ring.sortedHashes)) } @@ -118,7 +118,7 @@ func TestRemoveNode(t *testing.T) { } // Verify virtual nodes were removed - expectedHashes := VirtualNodes + expectedHashes := DefaultVirtualNodes if len(ring.sortedHashes) != expectedHashes { t.Errorf("expected %d virtual nodes, got %d", expectedHashes, len(ring.sortedHashes)) } @@ -321,7 +321,7 @@ func TestRingBehavior_ManyNodes(t *testing.T) { } // Verify virtual nodes count - expectedHashes := numNodes * VirtualNodes + expectedHashes := numNodes * DefaultVirtualNodes if len(ring.sortedHashes) != expectedHashes { t.Errorf("expected %d virtual nodes, got %d", expectedHashes, len(ring.sortedHashes)) } @@ -355,7 +355,7 @@ func TestRingBehavior_ManyNodes(t *testing.T) { } } -func TestVirtualNodes_ImproveDistribution(t *testing.T) { +func TestDefaultVirtualNodes_ImproveDistribution(t *testing.T) { // Test that virtual nodes actually improve distribution // by comparing with a theoretical single-hash-per-node scenario @@ -386,7 +386,7 @@ func TestVirtualNodes_ImproveDistribution(t *testing.T) { stdDev := math.Sqrt(sumSquaredDiff / float64(numNodes)) coefficientOfVariation := stdDev / expectedPerNode - // With VirtualNodes=150, we expect good distribution + // With DefaultVirtualNodes=150, we expect good distribution // Coefficient of variation should be low (< 15%) if coefficientOfVariation > 0.15 { t.Errorf("distribution has high coefficient of variation: %.2f%% (expected < 15%%)", @@ -394,8 +394,8 @@ func TestVirtualNodes_ImproveDistribution(t *testing.T) { } // Verify that the actual number of virtual nodes matches expected - if len(ring.sortedHashes) != numNodes*VirtualNodes { - t.Errorf("expected %d virtual node hashes, got %d", numNodes*VirtualNodes, len(ring.sortedHashes)) + if len(ring.sortedHashes) != numNodes*DefaultVirtualNodes { + t.Errorf("expected %d virtual node hashes, got %d", numNodes*DefaultVirtualNodes, len(ring.sortedHashes)) } } diff --git a/cluster/leader.go b/cluster/leader.go index b9a2607..bf352aa 100644 --- a/cluster/leader.go +++ b/cluster/leader.go @@ -44,8 +44,8 @@ func NewLeaderElection(nodeID string, natsConn *nats.Conn, callbacks LeaderElect Bucket: "aether-leader-election", Description: "Aether cluster leader election coordination", TTL: LeaderLeaseTimeout * 2, // Auto-cleanup expired leases - MaxBytes: 1024 * 1024, // 1MB max - Replicas: 1, // Single replica for simplicity + MaxBytes: 1024 * 1024, // 1MB max + Replicas: 1, // Single replica for simplicity }) if err != nil { // Try to get existing KV store @@ -411,4 +411,4 @@ func (le *LeaderElection) updateCurrentLeader(leaderID string, term uint64) { le.callbacks.OnNewLeader(leaderID) } } -} \ No newline at end of file +} diff --git a/cluster/manager.go b/cluster/manager.go index 2ca4c40..821f469 100644 --- a/cluster/manager.go +++ b/cluster/manager.go @@ -20,17 +20,17 @@ type VMRegistry interface { // ClusterManager coordinates distributed VM operations across the cluster type ClusterManager struct { - nodeID string - nodes map[string]*NodeInfo - nodeUpdates chan NodeUpdate - shardMap *ShardMap - hashRing *ConsistentHashRing - election *LeaderElection - natsConn *nats.Conn - ctx context.Context - mutex sync.RWMutex - logger *log.Logger - vmRegistry VMRegistry // Interface to access local VMs + nodeID string + nodes map[string]*NodeInfo + nodeUpdates chan NodeUpdate + shardMap *ShardMap + hashRing *ConsistentHashRing + election *LeaderElection + natsConn *nats.Conn + ctx context.Context + mutex sync.RWMutex + logger *log.Logger + vmRegistry VMRegistry // Interface to access local VMs } // NewClusterManager creates a cluster coordination manager @@ -328,4 +328,4 @@ func (cm *ClusterManager) GetShardMap() *ShardMap { Nodes: make(map[string]NodeInfo), UpdateTime: cm.shardMap.UpdateTime, } -} \ No newline at end of file +} diff --git a/cluster/shard.go b/cluster/shard.go index 46e270d..9038815 100644 --- a/cluster/shard.go +++ b/cluster/shard.go @@ -33,8 +33,26 @@ type ShardManager struct { replication int } -// NewShardManager creates a new shard manager +// NewShardManager creates a new shard manager with default configuration func NewShardManager(shardCount, replication int) *ShardManager { + return NewShardManagerWithConfig(ShardConfig{ + ShardCount: shardCount, + ReplicationFactor: replication, + }) +} + +// NewShardManagerWithConfig creates a new shard manager with custom configuration +func NewShardManagerWithConfig(config ShardConfig) *ShardManager { + // Apply defaults for zero values + shardCount := config.ShardCount + if shardCount == 0 { + shardCount = DefaultNumShards + } + replication := config.ReplicationFactor + if replication == 0 { + replication = 1 + } + return &ShardManager{ shardCount: shardCount, shardMap: &ShardMap{Shards: make(map[int][]string), Nodes: make(map[string]NodeInfo)}, @@ -149,6 +167,15 @@ func (sm *ShardManager) GetActorsInShard(shardID int, nodeID string, vmRegistry return actors } +// GetShardCount returns the total number of shards +func (sm *ShardManager) GetShardCount() int { + return sm.shardCount +} + +// GetReplicationFactor returns the replication factor +func (sm *ShardManager) GetReplicationFactor() int { + return sm.replication +} // ConsistentHashPlacement implements PlacementStrategy using consistent hashing type ConsistentHashPlacement struct{} @@ -185,4 +212,4 @@ func (chp *ConsistentHashPlacement) RebalanceShards(currentMap *ShardMap, nodes // This is a simplified implementation // In practice, this would implement sophisticated rebalancing logic return currentMap, nil -} \ No newline at end of file +} diff --git a/cluster/types.go b/cluster/types.go index 9ec9bbf..f0a13f5 100644 --- a/cluster/types.go +++ b/cluster/types.go @@ -4,17 +4,47 @@ import ( "time" ) +// Default configuration values const ( - // NumShards defines the total number of shards in the cluster - NumShards = 1024 - // VirtualNodes defines the number of virtual nodes per physical node for consistent hashing - VirtualNodes = 150 + // DefaultNumShards defines the default total number of shards in the cluster + DefaultNumShards = 1024 + // DefaultVirtualNodes defines the default number of virtual nodes per physical node + DefaultVirtualNodes = 150 // Leadership election constants LeaderLeaseTimeout = 10 * time.Second // How long a leader lease lasts HeartbeatInterval = 3 * time.Second // How often leader sends heartbeats ElectionTimeout = 2 * time.Second // How long to wait for election ) +// HashRingConfig holds configuration for the consistent hash ring +type HashRingConfig struct { + // VirtualNodes is the number of virtual nodes per physical node (default: 150) + VirtualNodes int +} + +// DefaultHashRingConfig returns the default hash ring configuration +func DefaultHashRingConfig() HashRingConfig { + return HashRingConfig{ + VirtualNodes: DefaultVirtualNodes, + } +} + +// ShardConfig holds configuration for shard management +type ShardConfig struct { + // ShardCount is the total number of shards (default: 1024) + ShardCount int + // ReplicationFactor is the number of replicas per shard (default: 1) + ReplicationFactor int +} + +// DefaultShardConfig returns the default shard configuration +func DefaultShardConfig() ShardConfig { + return ShardConfig{ + ShardCount: DefaultNumShards, + ReplicationFactor: 1, + } +} + // NodeStatus represents the health status of a node type NodeStatus string @@ -30,14 +60,14 @@ type NodeInfo struct { Address string `json:"address"` Port int `json:"port"` Status NodeStatus `json:"status"` - Capacity float64 `json:"capacity"` // Maximum load capacity - Load float64 `json:"load"` // Current CPU/memory load - LastSeen time.Time `json:"lastSeen"` // Last heartbeat timestamp + Capacity float64 `json:"capacity"` // Maximum load capacity + Load float64 `json:"load"` // Current CPU/memory load + LastSeen time.Time `json:"lastSeen"` // Last heartbeat timestamp Timestamp time.Time `json:"timestamp"` Metadata map[string]string `json:"metadata"` IsLeader bool `json:"isLeader"` - VMCount int `json:"vmCount"` // Number of VMs on this node - ShardIDs []int `json:"shardIds"` // Shards assigned to this node + VMCount int `json:"vmCount"` // Number of VMs on this node + ShardIDs []int `json:"shardIds"` // Shards assigned to this node } // NodeUpdateType represents the type of node update @@ -57,9 +87,9 @@ type NodeUpdate struct { // ShardMap represents the distribution of shards across cluster nodes type ShardMap struct { - Version uint64 `json:"version"` // Incremented on each change - Shards map[int][]string `json:"shards"` // shard ID -> [primary, replica1, replica2] - Nodes map[string]NodeInfo `json:"nodes"` // node ID -> node info + Version uint64 `json:"version"` // Incremented on each change + Shards map[int][]string `json:"shards"` // shard ID -> [primary, replica1, replica2] + Nodes map[string]NodeInfo `json:"nodes"` // node ID -> node info UpdateTime time.Time `json:"updateTime"` } @@ -74,23 +104,23 @@ type ClusterMessage struct { // RebalanceRequest represents a request to rebalance shards type RebalanceRequest struct { - RequestID string `json:"requestId"` - FromNode string `json:"fromNode"` - ToNode string `json:"toNode"` - ShardIDs []int `json:"shardIds"` - Reason string `json:"reason"` - Migrations []ActorMigration `json:"migrations"` + RequestID string `json:"requestId"` + FromNode string `json:"fromNode"` + ToNode string `json:"toNode"` + ShardIDs []int `json:"shardIds"` + Reason string `json:"reason"` + Migrations []ActorMigration `json:"migrations"` } // ActorMigration represents the migration of an actor between nodes type ActorMigration struct { - ActorID string `json:"actorId"` - FromNode string `json:"fromNode"` - ToNode string `json:"toNode"` - ShardID int `json:"shardId"` - State map[string]interface{} `json:"state"` - Version int64 `json:"version"` - Status string `json:"status"` // "pending", "in_progress", "completed", "failed" + ActorID string `json:"actorId"` + FromNode string `json:"fromNode"` + ToNode string `json:"toNode"` + ShardID int `json:"shardId"` + State map[string]interface{} `json:"state"` + Version int64 `json:"version"` + Status string `json:"status"` // "pending", "in_progress", "completed", "failed" } // LeaderElectionCallbacks defines callbacks for leadership changes @@ -107,4 +137,3 @@ type LeadershipLease struct { ExpiresAt time.Time `json:"expiresAt"` StartedAt time.Time `json:"startedAt"` } - diff --git a/nats_eventbus.go b/nats_eventbus.go index 746bc81..a13653e 100644 --- a/nats_eventbus.go +++ b/nats_eventbus.go @@ -13,11 +13,11 @@ import ( // NATSEventBus is an EventBus that broadcasts events across all cluster nodes using NATS type NATSEventBus struct { - *EventBus // Embed base EventBus for local subscriptions - nc *nats.Conn // NATS connection + *EventBus // Embed base EventBus for local subscriptions + nc *nats.Conn // NATS connection subscriptions []*nats.Subscription - namespaceSubscribers map[string]int // Track number of subscribers per namespace - nodeID string // Unique ID for this node + namespaceSubscribers map[string]int // Track number of subscribers per namespace + nodeID string // Unique ID for this node mutex sync.Mutex ctx context.Context cancel context.CancelFunc diff --git a/store/config_test.go b/store/config_test.go new file mode 100644 index 0000000..f1714ed --- /dev/null +++ b/store/config_test.go @@ -0,0 +1,46 @@ +package store + +import ( + "testing" + "time" +) + +func TestDefaultJetStreamConfig(t *testing.T) { + config := DefaultJetStreamConfig() + + if config.StreamRetention != DefaultStreamRetention { + t.Errorf("expected StreamRetention=%v, got %v", DefaultStreamRetention, config.StreamRetention) + } + if config.ReplicaCount != DefaultReplicaCount { + t.Errorf("expected ReplicaCount=%d, got %d", DefaultReplicaCount, config.ReplicaCount) + } +} + +func TestJetStreamConfigDefaults(t *testing.T) { + t.Run("default stream retention is 1 year", func(t *testing.T) { + expected := 365 * 24 * time.Hour + if DefaultStreamRetention != expected { + t.Errorf("expected DefaultStreamRetention=%v, got %v", expected, DefaultStreamRetention) + } + }) + + t.Run("default replica count is 1", func(t *testing.T) { + if DefaultReplicaCount != 1 { + t.Errorf("expected DefaultReplicaCount=1, got %d", DefaultReplicaCount) + } + }) +} + +func TestJetStreamConfigCustomValues(t *testing.T) { + config := JetStreamConfig{ + StreamRetention: 30 * 24 * time.Hour, // 30 days + ReplicaCount: 3, + } + + if config.StreamRetention != 30*24*time.Hour { + t.Errorf("expected StreamRetention=30 days, got %v", config.StreamRetention) + } + if config.ReplicaCount != 3 { + t.Errorf("expected ReplicaCount=3, got %d", config.ReplicaCount) + } +} diff --git a/store/jetstream.go b/store/jetstream.go index b98edcf..f30967e 100644 --- a/store/jetstream.go +++ b/store/jetstream.go @@ -11,29 +11,65 @@ import ( "github.com/nats-io/nats.go" ) +// Default configuration values for JetStream event store +const ( + DefaultStreamRetention = 365 * 24 * time.Hour // 1 year + DefaultReplicaCount = 1 +) + +// JetStreamConfig holds configuration options for JetStreamEventStore +type JetStreamConfig struct { + // StreamRetention is how long to keep events (default: 1 year) + StreamRetention time.Duration + // ReplicaCount is the number of replicas for high availability (default: 1) + ReplicaCount int +} + +// DefaultJetStreamConfig returns the default configuration +func DefaultJetStreamConfig() JetStreamConfig { + return JetStreamConfig{ + StreamRetention: DefaultStreamRetention, + ReplicaCount: DefaultReplicaCount, + } +} + // JetStreamEventStore implements EventStore using NATS JetStream for persistence type JetStreamEventStore struct { js nats.JetStreamContext streamName string - mu sync.Mutex // Protects version checks during SaveEvent - versions map[string]int64 // actorID -> latest version cache + config JetStreamConfig + mu sync.Mutex // Protects version checks during SaveEvent + versions map[string]int64 // actorID -> latest version cache } -// NewJetStreamEventStore creates a new JetStream-based event store +// NewJetStreamEventStore creates a new JetStream-based event store with default configuration func NewJetStreamEventStore(natsConn *nats.Conn, streamName string) (*JetStreamEventStore, error) { + return NewJetStreamEventStoreWithConfig(natsConn, streamName, DefaultJetStreamConfig()) +} + +// NewJetStreamEventStoreWithConfig creates a new JetStream-based event store with custom configuration +func NewJetStreamEventStoreWithConfig(natsConn *nats.Conn, streamName string, config JetStreamConfig) (*JetStreamEventStore, error) { js, err := natsConn.JetStream() if err != nil { return nil, fmt.Errorf("failed to get JetStream context: %w", err) } + // Apply defaults for zero values + if config.StreamRetention == 0 { + config.StreamRetention = DefaultStreamRetention + } + if config.ReplicaCount == 0 { + config.ReplicaCount = DefaultReplicaCount + } + // Create or update the stream stream := &nats.StreamConfig{ Name: streamName, Subjects: []string{fmt.Sprintf("%s.events.>", streamName), fmt.Sprintf("%s.snapshots.>", streamName)}, Storage: nats.FileStorage, Retention: nats.LimitsPolicy, - MaxAge: 365 * 24 * time.Hour, // Keep events for 1 year - Replicas: 1, // Can be increased for HA + MaxAge: config.StreamRetention, + Replicas: config.ReplicaCount, } _, err = js.AddStream(stream) @@ -44,6 +80,7 @@ func NewJetStreamEventStore(natsConn *nats.Conn, streamName string) (*JetStreamE return &JetStreamEventStore{ js: js, streamName: streamName, + config: config, versions: make(map[string]int64), }, nil }