Compare commits
5 commits: b630258f60...e77a3a9868

Commits in this range:
- e77a3a9868
- 8df36cac7a
- b759c7fb97
- eaff315782
- c757bb76f3
@@ -44,5 +44,4 @@
// - Leader election ensures coordination continues despite node failures
// - Actor migration allows rebalancing when cluster topology changes
// - Graceful shutdown with proper resource cleanup
//
package cluster

cluster/config_test.go (new file, 125 changed lines)
@@ -0,0 +1,125 @@
package cluster

import (
    "strconv"
    "testing"
)

func TestDefaultHashRingConfig(t *testing.T) {
    config := DefaultHashRingConfig()

    if config.VirtualNodes != DefaultVirtualNodes {
        t.Errorf("expected VirtualNodes=%d, got %d", DefaultVirtualNodes, config.VirtualNodes)
    }
}

func TestDefaultShardConfig(t *testing.T) {
    config := DefaultShardConfig()

    if config.ShardCount != DefaultNumShards {
        t.Errorf("expected ShardCount=%d, got %d", DefaultNumShards, config.ShardCount)
    }
    if config.ReplicationFactor != 1 {
        t.Errorf("expected ReplicationFactor=1, got %d", config.ReplicationFactor)
    }
}

func TestNewConsistentHashRingWithConfig(t *testing.T) {
    t.Run("custom virtual nodes", func(t *testing.T) {
        config := HashRingConfig{VirtualNodes: 50}
        ring := NewConsistentHashRingWithConfig(config)

        ring.AddNode("test-node")

        if len(ring.sortedHashes) != 50 {
            t.Errorf("expected 50 virtual nodes, got %d", len(ring.sortedHashes))
        }
        if ring.GetVirtualNodes() != 50 {
            t.Errorf("expected GetVirtualNodes()=50, got %d", ring.GetVirtualNodes())
        }
    })

    t.Run("zero value uses default", func(t *testing.T) {
        config := HashRingConfig{VirtualNodes: 0}
        ring := NewConsistentHashRingWithConfig(config)

        ring.AddNode("test-node")

        if len(ring.sortedHashes) != DefaultVirtualNodes {
            t.Errorf("expected %d virtual nodes, got %d", DefaultVirtualNodes, len(ring.sortedHashes))
        }
    })

    t.Run("default constructor uses default config", func(t *testing.T) {
        ring := NewConsistentHashRing()

        ring.AddNode("test-node")

        if len(ring.sortedHashes) != DefaultVirtualNodes {
            t.Errorf("expected %d virtual nodes, got %d", DefaultVirtualNodes, len(ring.sortedHashes))
        }
    })
}

func TestNewShardManagerWithConfig(t *testing.T) {
    t.Run("custom shard count", func(t *testing.T) {
        config := ShardConfig{ShardCount: 256, ReplicationFactor: 2}
        sm := NewShardManagerWithConfig(config)

        if sm.GetShardCount() != 256 {
            t.Errorf("expected shard count 256, got %d", sm.GetShardCount())
        }
        if sm.GetReplicationFactor() != 2 {
            t.Errorf("expected replication factor 2, got %d", sm.GetReplicationFactor())
        }
    })

    t.Run("zero values use defaults", func(t *testing.T) {
        config := ShardConfig{ShardCount: 0, ReplicationFactor: 0}
        sm := NewShardManagerWithConfig(config)

        if sm.GetShardCount() != DefaultNumShards {
            t.Errorf("expected shard count %d, got %d", DefaultNumShards, sm.GetShardCount())
        }
        if sm.GetReplicationFactor() != 1 {
            t.Errorf("expected replication factor 1, got %d", sm.GetReplicationFactor())
        }
    })

    t.Run("legacy constructor still works", func(t *testing.T) {
        sm := NewShardManager(512, 3)

        if sm.GetShardCount() != 512 {
            t.Errorf("expected shard count 512, got %d", sm.GetShardCount())
        }
        if sm.GetReplicationFactor() != 3 {
            t.Errorf("expected replication factor 3, got %d", sm.GetReplicationFactor())
        }
    })
}

func TestShardManagerGetShard_DifferentShardCounts(t *testing.T) {
    testCases := []struct {
        shardCount int
    }{
        {shardCount: 16},
        {shardCount: 64},
        {shardCount: 256},
        {shardCount: 1024},
        {shardCount: 4096},
    }

    for _, tc := range testCases {
        // Use strconv.Itoa for readable names; string(rune(n)) would yield an
        // arbitrary Unicode character rather than the decimal number.
        t.Run("shardCount="+strconv.Itoa(tc.shardCount), func(t *testing.T) {
            sm := NewShardManagerWithConfig(ShardConfig{ShardCount: tc.shardCount})

            // Verify all actor IDs map to valid shard range
            for i := 0; i < 1000; i++ {
                actorID := "actor-" + strconv.Itoa(i)
                shard := sm.GetShard(actorID)
                if shard < 0 || shard >= tc.shardCount {
                    t.Errorf("shard %d out of range [0, %d)", shard, tc.shardCount)
                }
            }
        })
    }
}
@@ -12,7 +12,7 @@ import (
type DistributedVM struct {
    nodeID       string
    cluster      *ClusterManager
    localRuntime Runtime // Interface to avoid import cycles
    localRuntime Runtime
    sharding     *ShardManager
    discovery    *NodeDiscovery
    natsConn     *nats.Conn
@@ -20,19 +20,31 @@ type DistributedVM struct {
    cancel context.CancelFunc
}

// Runtime interface to avoid import cycles with main aether package
// Runtime defines the interface for a local runtime that executes actors.
// This interface decouples the cluster package from specific runtime implementations.
type Runtime interface {
    // Start initializes and starts the runtime
    Start() error
    LoadModel(model interface{}) error
    SendMessage(message interface{}) error
    // LoadModel loads an EventStorming model into the runtime
    LoadModel(model RuntimeModel) error
    // SendMessage sends a message to an actor in the runtime
    SendMessage(message RuntimeMessage) error
}

// DistributedVMRegistry implements VMRegistry using DistributedVM's local runtime and sharding
// DistributedVMRegistry implements VMRegistry using DistributedVM's local runtime and sharding.
// It provides the cluster manager with access to VM information without import cycles.
type DistributedVMRegistry struct {
    runtime    interface{} // Runtime interface to avoid import cycles
    vmProvider VMProvider
    sharding   *ShardManager
}

// VMProvider defines an interface for accessing VMs from a runtime.
// This is used by DistributedVMRegistry to get VM information.
type VMProvider interface {
    // GetActiveVMs returns a map of actor IDs to their VirtualMachine instances
    GetActiveVMs() map[string]VirtualMachine
}

// NewDistributedVM creates a distributed VM runtime cluster node
func NewDistributedVM(nodeID string, natsURLs []string, localRuntime Runtime) (*DistributedVM, error) {
    ctx, cancel := context.WithCancel(context.Background())
@@ -67,16 +79,19 @@ func NewDistributedVM(nodeID string, natsURLs []string, localRuntime Runtime) (*
        cancel: cancel,
    }

    // Create VM registry and connect it to cluster manager
    vmRegistry := &DistributedVMRegistry{
        runtime:  localRuntime,
        sharding: sharding,
    }
    cluster.SetVMRegistry(vmRegistry)

    return dvm, nil
}

// SetVMProvider sets the VM provider for the distributed VM registry.
// This should be called after the runtime is fully initialized.
func (dvm *DistributedVM) SetVMProvider(provider VMProvider) {
    vmRegistry := &DistributedVMRegistry{
        vmProvider: provider,
        sharding:   dvm.sharding,
    }
    dvm.cluster.SetVMRegistry(vmRegistry)
}

// Start begins the distributed VM cluster node
func (dvm *DistributedVM) Start() error {
    // Start local runtime
@@ -103,7 +118,7 @@ func (dvm *DistributedVM) Stop() {
}

// LoadModel distributes EventStorming model across the cluster with VM templates
func (dvm *DistributedVM) LoadModel(model interface{}) error {
func (dvm *DistributedVM) LoadModel(model RuntimeModel) error {
    // Load model locally first
    if err := dvm.localRuntime.LoadModel(model); err != nil {
        return fmt.Errorf("failed to load model locally: %w", err)
@@ -121,7 +136,7 @@ func (dvm *DistributedVM) LoadModel(model interface{}) error {
}

// SendMessage routes messages across the distributed cluster
func (dvm *DistributedVM) SendMessage(message interface{}) error {
func (dvm *DistributedVM) SendMessage(message RuntimeMessage) error {
    // This is a simplified implementation
    // In practice, this would determine the target node based on sharding
    // and route the message appropriately
@@ -162,15 +177,29 @@ func (dvm *DistributedVM) handleClusterMessage(msg *nats.Msg) {
    switch clusterMsg.Type {
    case "load_model":
        // Handle model loading from other nodes
        if model := clusterMsg.Payload; model != nil {
            dvm.localRuntime.LoadModel(model)
        // Re-marshal and unmarshal to convert map[string]interface{} to concrete type
        payloadBytes, err := json.Marshal(clusterMsg.Payload)
        if err != nil {
            return
        }
        var model ModelPayload
        if err := json.Unmarshal(payloadBytes, &model); err != nil {
            return
        }
        dvm.localRuntime.LoadModel(&model)

    case "route_message":
        // Handle message routing from other nodes
        if message := clusterMsg.Payload; message != nil {
            dvm.localRuntime.SendMessage(message)
        // Re-marshal and unmarshal to convert map[string]interface{} to concrete type
        payloadBytes, err := json.Marshal(clusterMsg.Payload)
        if err != nil {
            return
        }
        var message MessagePayload
        if err := json.Unmarshal(payloadBytes, &message); err != nil {
            return
        }
        dvm.localRuntime.SendMessage(&message)

    case "rebalance":
        // Handle shard rebalancing requests
@@ -208,11 +237,12 @@ func (dvm *DistributedVM) GetClusterInfo() map[string]interface{} {
    }
}

// GetActiveVMs returns a map of active VMs (implementation depends on runtime)
func (dvr *DistributedVMRegistry) GetActiveVMs() map[string]interface{} {
    // This would need to access the actual runtime's VM registry
    // For now, return empty map to avoid import cycles
    return make(map[string]interface{})
// GetActiveVMs returns a map of active VMs from the VM provider
func (dvr *DistributedVMRegistry) GetActiveVMs() map[string]VirtualMachine {
    if dvr.vmProvider == nil {
        return make(map[string]VirtualMachine)
    }
    return dvr.vmProvider.GetActiveVMs()
}

// GetShard returns the shard number for the given actor ID

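The interplay between NewDistributedVM and SetVMProvider is easy to miss: the constructor installs a registry with no VM provider, and SetVMProvider swaps in one that can actually report VMs. A minimal wiring sketch, written as if it lived in the cluster package itself; the stubRuntime type and the NATS URL are invented for illustration and are not part of this change:

```go
package cluster

// stubRuntime is a hypothetical Runtime + VMProvider used only to show the wiring.
type stubRuntime struct {
	vms map[string]VirtualMachine
}

func (r *stubRuntime) Start() error                             { return nil }
func (r *stubRuntime) LoadModel(model RuntimeModel) error       { return nil }
func (r *stubRuntime) SendMessage(message RuntimeMessage) error { return nil }

// GetActiveVMs also makes stubRuntime satisfy VMProvider.
func (r *stubRuntime) GetActiveVMs() map[string]VirtualMachine { return r.vms }

func exampleWiring() error {
	rt := &stubRuntime{vms: make(map[string]VirtualMachine)}

	// The constructor registers a DistributedVMRegistry without a VMProvider yet.
	dvm, err := NewDistributedVM("node-1", []string{"nats://127.0.0.1:4222"}, rt)
	if err != nil {
		return err
	}

	// Once the runtime is fully initialized, install it as the VM provider so
	// GetActiveVMs returns real VMs instead of an empty map.
	dvm.SetVMProvider(rt)
	return dvm.Start()
}
```

If SetVMProvider is never called, the registry's GetActiveVMs falls back to an empty map, which is exactly the nil check added above.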
@@ -12,13 +12,24 @@ type ConsistentHashRing struct {
    ring         map[uint32]string // hash -> node ID
    sortedHashes []uint32          // sorted hash keys
    nodes        map[string]bool   // active nodes
    virtualNodes int               // number of virtual nodes per physical node
}

// NewConsistentHashRing creates a new consistent hash ring
// NewConsistentHashRing creates a new consistent hash ring with default configuration
func NewConsistentHashRing() *ConsistentHashRing {
    return NewConsistentHashRingWithConfig(DefaultHashRingConfig())
}

// NewConsistentHashRingWithConfig creates a new consistent hash ring with custom configuration
func NewConsistentHashRingWithConfig(config HashRingConfig) *ConsistentHashRing {
    virtualNodes := config.VirtualNodes
    if virtualNodes == 0 {
        virtualNodes = DefaultVirtualNodes
    }
    return &ConsistentHashRing{
        ring:         make(map[uint32]string),
        nodes:        make(map[string]bool),
        virtualNodes: virtualNodes,
    }
}

@@ -31,7 +42,7 @@ func (chr *ConsistentHashRing) AddNode(nodeID string) {
    chr.nodes[nodeID] = true

    // Add virtual nodes for better distribution
    for i := 0; i < VirtualNodes; i++ {
    for i := 0; i < chr.virtualNodes; i++ {
        virtualKey := fmt.Sprintf("%s:%d", nodeID, i)
        hash := chr.hash(virtualKey)
        chr.ring[hash] = nodeID
@@ -103,3 +114,8 @@ func (chr *ConsistentHashRing) GetNodes() []string {
func (chr *ConsistentHashRing) IsEmpty() bool {
    return len(chr.nodes) == 0
}

// GetVirtualNodes returns the number of virtual nodes per physical node
func (chr *ConsistentHashRing) GetVirtualNodes() int {
    return chr.virtualNodes
}

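Since the virtual-node count now comes from HashRingConfig instead of the package constant, callers can trade memory for distribution quality per ring. A brief usage sketch, assuming only the constructors and accessors shown above (the node names are illustrative):

```go
package cluster

import "fmt"

func exampleRingConfig() {
	// Fewer virtual nodes than the default of 150: less memory per ring,
	// coarser key distribution across physical nodes.
	ring := NewConsistentHashRingWithConfig(HashRingConfig{VirtualNodes: 64})
	ring.AddNode("node-a")
	ring.AddNode("node-b")

	fmt.Println(ring.GetVirtualNodes()) // 64
	fmt.Println(len(ring.GetNodes()))   // 2
	fmt.Println(ring.IsEmpty())         // false
}
```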
@@ -42,7 +42,7 @@ func TestAddNode(t *testing.T) {
    }

    // Verify virtual nodes were added
    expectedVirtualNodes := VirtualNodes
    expectedVirtualNodes := DefaultVirtualNodes
    if len(ring.sortedHashes) != expectedVirtualNodes {
        t.Errorf("expected %d virtual nodes, got %d", expectedVirtualNodes, len(ring.sortedHashes))
    }
@@ -86,7 +86,7 @@ func TestAddNode_MultipleNodes(t *testing.T) {
        t.Errorf("expected 3 nodes, got %d", len(nodes))
    }

    expectedHashes := VirtualNodes * 3
    expectedHashes := DefaultVirtualNodes * 3
    if len(ring.sortedHashes) != expectedHashes {
        t.Errorf("expected %d virtual nodes, got %d", expectedHashes, len(ring.sortedHashes))
    }
@@ -118,7 +118,7 @@ func TestRemoveNode(t *testing.T) {
    }

    // Verify virtual nodes were removed
    expectedHashes := VirtualNodes
    expectedHashes := DefaultVirtualNodes
    if len(ring.sortedHashes) != expectedHashes {
        t.Errorf("expected %d virtual nodes, got %d", expectedHashes, len(ring.sortedHashes))
    }
@@ -321,7 +321,7 @@ func TestRingBehavior_ManyNodes(t *testing.T) {
    }

    // Verify virtual nodes count
    expectedHashes := numNodes * VirtualNodes
    expectedHashes := numNodes * DefaultVirtualNodes
    if len(ring.sortedHashes) != expectedHashes {
        t.Errorf("expected %d virtual nodes, got %d", expectedHashes, len(ring.sortedHashes))
    }
@@ -355,7 +355,7 @@ func TestRingBehavior_ManyNodes(t *testing.T) {
    }
}

func TestVirtualNodes_ImproveDistribution(t *testing.T) {
func TestDefaultVirtualNodes_ImproveDistribution(t *testing.T) {
    // Test that virtual nodes actually improve distribution
    // by comparing with a theoretical single-hash-per-node scenario

@@ -386,7 +386,7 @@ func TestVirtualNodes_ImproveDistribution(t *testing.T) {
    stdDev := math.Sqrt(sumSquaredDiff / float64(numNodes))
    coefficientOfVariation := stdDev / expectedPerNode

    // With VirtualNodes=150, we expect good distribution
    // With DefaultVirtualNodes=150, we expect good distribution
    // Coefficient of variation should be low (< 15%)
    if coefficientOfVariation > 0.15 {
        t.Errorf("distribution has high coefficient of variation: %.2f%% (expected < 15%%)",
@@ -394,8 +394,8 @@ func TestVirtualNodes_ImproveDistribution(t *testing.T) {
    }

    // Verify that the actual number of virtual nodes matches expected
    if len(ring.sortedHashes) != numNodes*VirtualNodes {
        t.Errorf("expected %d virtual node hashes, got %d", numNodes*VirtualNodes, len(ring.sortedHashes))
    if len(ring.sortedHashes) != numNodes*DefaultVirtualNodes {
        t.Errorf("expected %d virtual node hashes, got %d", numNodes*DefaultVirtualNodes, len(ring.sortedHashes))
    }
}

@@ -12,9 +12,12 @@ import (
    "github.com/nats-io/nats.go"
)

// VMRegistry provides access to local VM information for cluster operations
// VMRegistry provides access to local VM information for cluster operations.
// Implementations must provide thread-safe access to VM data.
type VMRegistry interface {
    GetActiveVMs() map[string]interface{} // VirtualMachine interface to avoid import cycles
    // GetActiveVMs returns a map of actor IDs to their VirtualMachine instances
    GetActiveVMs() map[string]VirtualMachine
    // GetShard returns the shard number for a given actor ID
    GetShard(actorID string) int
}

@@ -50,13 +53,13 @@ func NewClusterManager(nodeID string, natsConn *nats.Conn, ctx context.Context)
    // Create leadership election with callbacks
    callbacks := LeaderElectionCallbacks{
        OnBecameLeader: func() {
            cm.logger.Printf("👑 This node became the cluster leader - can initiate rebalancing")
            cm.logger.Printf("This node became the cluster leader - can initiate rebalancing")
        },
        OnLostLeader: func() {
            cm.logger.Printf("📉 This node lost cluster leadership")
            cm.logger.Printf("This node lost cluster leadership")
        },
        OnNewLeader: func(leaderID string) {
            cm.logger.Printf("🔄 Cluster leadership changed to: %s", leaderID)
            cm.logger.Printf("Cluster leadership changed to: %s", leaderID)
        },
    }

@@ -71,7 +74,7 @@ func NewClusterManager(nodeID string, natsConn *nats.Conn, ctx context.Context)

// Start begins cluster management operations
func (cm *ClusterManager) Start() {
    cm.logger.Printf("🚀 Starting cluster manager")
    cm.logger.Printf("Starting cluster manager")

    // Start leader election
    cm.election.Start()
@@ -88,7 +91,7 @@ func (cm *ClusterManager) Start() {

// Stop gracefully stops the cluster manager
func (cm *ClusterManager) Stop() {
    cm.logger.Printf("🛑 Stopping cluster manager")
    cm.logger.Printf("Stopping cluster manager")

    if cm.election != nil {
        cm.election.Stop()
@@ -138,7 +141,7 @@ func (cm *ClusterManager) GetActorsInShard(shardID int) []string {
func (cm *ClusterManager) handleClusterMessage(msg *nats.Msg) {
    var clusterMsg ClusterMessage
    if err := json.Unmarshal(msg.Data, &clusterMsg); err != nil {
        cm.logger.Printf("⚠️ Invalid cluster message: %v", err)
        cm.logger.Printf("Invalid cluster message: %v", err)
        return
    }

@@ -152,7 +155,7 @@ func (cm *ClusterManager) handleClusterMessage(msg *nats.Msg) {
        cm.handleNodeUpdate(update)
    }
    default:
        cm.logger.Printf("⚠️ Unknown cluster message type: %s", clusterMsg.Type)
        cm.logger.Printf("Unknown cluster message type: %s", clusterMsg.Type)
    }
}

@@ -165,12 +168,12 @@ func (cm *ClusterManager) handleNodeUpdate(update NodeUpdate) {
    case NodeJoined:
        cm.nodes[update.Node.ID] = update.Node
        cm.hashRing.AddNode(update.Node.ID)
        cm.logger.Printf("➕ Node joined: %s", update.Node.ID)
        cm.logger.Printf("Node joined: %s", update.Node.ID)

    case NodeLeft:
        delete(cm.nodes, update.Node.ID)
        cm.hashRing.RemoveNode(update.Node.ID)
        cm.logger.Printf("➖ Node left: %s", update.Node.ID)
        cm.logger.Printf("Node left: %s", update.Node.ID)

    case NodeUpdated:
        if node, exists := cm.nodes[update.Node.ID]; exists {
@@ -188,7 +191,7 @@ func (cm *ClusterManager) handleNodeUpdate(update NodeUpdate) {
    for _, node := range cm.nodes {
        if now.Sub(node.LastSeen) > 90*time.Second && node.Status != NodeStatusFailed {
            node.Status = NodeStatusFailed
            cm.logger.Printf("❌ Node marked as failed: %s (last seen: %s)",
            cm.logger.Printf("Node marked as failed: %s (last seen: %s)",
                node.ID, node.LastSeen.Format(time.RFC3339))
        }
    }
@@ -212,7 +215,7 @@ func (cm *ClusterManager) handleNodeUpdate(update NodeUpdate) {

// handleRebalanceRequest processes cluster rebalancing requests
func (cm *ClusterManager) handleRebalanceRequest(msg ClusterMessage) {
    cm.logger.Printf("🔄 Handling rebalance request from %s", msg.From)
    cm.logger.Printf("Handling rebalance request from %s", msg.From)

    // Implementation would handle the specific rebalancing logic
    // This is a simplified version
@@ -220,7 +223,7 @@ func (cm *ClusterManager) handleRebalanceRequest(msg ClusterMessage) {

// handleMigrationRequest processes actor migration requests
func (cm *ClusterManager) handleMigrationRequest(msg ClusterMessage) {
    cm.logger.Printf("🚚 Handling migration request from %s", msg.From)
    cm.logger.Printf("Handling migration request from %s", msg.From)

    // Implementation would handle the specific migration logic
    // This is a simplified version
@@ -232,7 +235,7 @@ func (cm *ClusterManager) triggerShardRebalancing(reason string) {
        return // Only leader can initiate rebalancing
    }

    cm.logger.Printf("⚖️ Triggering shard rebalancing: %s", reason)
    cm.logger.Printf("Triggering shard rebalancing: %s", reason)

    // Get active nodes
    var activeNodes []*NodeInfo
@@ -245,12 +248,12 @@ func (cm *ClusterManager) triggerShardRebalancing(reason string) {
    cm.mutex.RUnlock()

    if len(activeNodes) == 0 {
        cm.logger.Printf("⚠️ No active nodes available for rebalancing")
        cm.logger.Printf("No active nodes available for rebalancing")
        return
    }

    // This would implement the actual rebalancing logic
    cm.logger.Printf("🎯 Would rebalance across %d active nodes", len(activeNodes))
    cm.logger.Printf("Would rebalance across %d active nodes", len(activeNodes))
}

// monitorNodes periodically checks node health and updates
@@ -279,7 +282,7 @@ func (cm *ClusterManager) checkNodeHealth() {
    for _, node := range cm.nodes {
        if now.Sub(node.LastSeen) > 90*time.Second && node.Status == NodeStatusActive {
            node.Status = NodeStatusFailed
            cm.logger.Printf("💔 Node failed: %s", node.ID)
            cm.logger.Printf("Node failed: %s", node.ID)
        }
    }
}

@@ -33,8 +33,26 @@ type ShardManager struct {
    replication int
}

// NewShardManager creates a new shard manager
// NewShardManager creates a new shard manager with default configuration
func NewShardManager(shardCount, replication int) *ShardManager {
    return NewShardManagerWithConfig(ShardConfig{
        ShardCount:        shardCount,
        ReplicationFactor: replication,
    })
}

// NewShardManagerWithConfig creates a new shard manager with custom configuration
func NewShardManagerWithConfig(config ShardConfig) *ShardManager {
    // Apply defaults for zero values
    shardCount := config.ShardCount
    if shardCount == 0 {
        shardCount = DefaultNumShards
    }
    replication := config.ReplicationFactor
    if replication == 0 {
        replication = 1
    }

    return &ShardManager{
        shardCount: shardCount,
        shardMap:   &ShardMap{Shards: make(map[int][]string), Nodes: make(map[string]NodeInfo)},
@@ -149,6 +167,15 @@ func (sm *ShardManager) GetActorsInShard(shardID int, nodeID string, vmRegistry
    return actors
}

// GetShardCount returns the total number of shards
func (sm *ShardManager) GetShardCount() int {
    return sm.shardCount
}

// GetReplicationFactor returns the replication factor
func (sm *ShardManager) GetReplicationFactor() int {
    return sm.replication
}

// ConsistentHashPlacement implements PlacementStrategy using consistent hashing
type ConsistentHashPlacement struct{}

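One consequence of the zero-value defaulting in NewShardManagerWithConfig is that a partially filled ShardConfig is safe: unset fields fall back to DefaultNumShards and a replication factor of 1. A short sketch of the intended call pattern, assuming only the constructors and getters shown above (the actor ID is illustrative):

```go
package cluster

import "fmt"

func exampleShardConfig() {
	// Only ShardCount is set; ReplicationFactor falls back to 1.
	sm := NewShardManagerWithConfig(ShardConfig{ShardCount: 256})

	// GetShard maps an actor ID to a shard index in [0, 256).
	shard := sm.GetShard("order-42")
	fmt.Printf("actor order-42 -> shard %d of %d (replication %d)\n",
		shard, sm.GetShardCount(), sm.GetReplicationFactor())
}
```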
cluster/types.go (103 changed lines)
@@ -4,17 +4,47 @@ import (
    "time"
)

// Default configuration values
const (
    // NumShards defines the total number of shards in the cluster
    NumShards = 1024
    // VirtualNodes defines the number of virtual nodes per physical node for consistent hashing
    VirtualNodes = 150
    // DefaultNumShards defines the default total number of shards in the cluster
    DefaultNumShards = 1024
    // DefaultVirtualNodes defines the default number of virtual nodes per physical node
    DefaultVirtualNodes = 150
    // Leadership election constants
    LeaderLeaseTimeout = 10 * time.Second // How long a leader lease lasts
    HeartbeatInterval  = 3 * time.Second  // How often leader sends heartbeats
    ElectionTimeout    = 2 * time.Second  // How long to wait for election
)

// HashRingConfig holds configuration for the consistent hash ring
type HashRingConfig struct {
    // VirtualNodes is the number of virtual nodes per physical node (default: 150)
    VirtualNodes int
}

// DefaultHashRingConfig returns the default hash ring configuration
func DefaultHashRingConfig() HashRingConfig {
    return HashRingConfig{
        VirtualNodes: DefaultVirtualNodes,
    }
}

// ShardConfig holds configuration for shard management
type ShardConfig struct {
    // ShardCount is the total number of shards (default: 1024)
    ShardCount int
    // ReplicationFactor is the number of replicas per shard (default: 1)
    ReplicationFactor int
}

// DefaultShardConfig returns the default shard configuration
func DefaultShardConfig() ShardConfig {
    return ShardConfig{
        ShardCount:        DefaultNumShards,
        ReplicationFactor: 1,
    }
}

// NodeStatus represents the health status of a node
type NodeStatus string

@@ -108,3 +138,68 @@ type LeadershipLease struct {
    StartedAt time.Time `json:"startedAt"`
}

// VirtualMachine defines the interface for a virtual machine instance.
// This interface provides the minimal contract needed by the cluster package
// to interact with VMs without creating import cycles with the runtime package.
type VirtualMachine interface {
    // GetID returns the unique identifier for this VM
    GetID() string
    // GetActorID returns the actor ID this VM represents
    GetActorID() string
    // GetState returns the current state of the VM
    GetState() VMState
}

// VMState represents the state of a virtual machine
type VMState string

const (
    VMStateIdle    VMState = "idle"
    VMStateRunning VMState = "running"
    VMStatePaused  VMState = "paused"
    VMStateStopped VMState = "stopped"
)

// RuntimeModel defines the interface for an EventStorming model that can be loaded into a runtime.
// This decouples the cluster package from the specific eventstorming package.
type RuntimeModel interface {
    // GetID returns the unique identifier for this model
    GetID() string
    // GetName returns the name of this model
    GetName() string
}

// RuntimeMessage defines the interface for messages that can be sent through the runtime.
// This provides type safety for inter-actor communication without creating import cycles.
type RuntimeMessage interface {
    // GetTargetActorID returns the ID of the actor this message is addressed to
    GetTargetActorID() string
    // GetType returns the message type identifier
    GetType() string
}

// ModelPayload is a concrete type for JSON-unmarshaling RuntimeModel payloads.
// Use this when receiving model data over the network.
type ModelPayload struct {
    ID   string `json:"id"`
    Name string `json:"name"`
}

// GetID implements RuntimeModel
func (m *ModelPayload) GetID() string { return m.ID }

// GetName implements RuntimeModel
func (m *ModelPayload) GetName() string { return m.Name }

// MessagePayload is a concrete type for JSON-unmarshaling RuntimeMessage payloads.
// Use this when receiving message data over the network.
type MessagePayload struct {
    TargetActorID string `json:"targetActorId"`
    Type          string `json:"type"`
}

// GetTargetActorID implements RuntimeMessage
func (m *MessagePayload) GetTargetActorID() string { return m.TargetActorID }

// GetType implements RuntimeMessage
func (m *MessagePayload) GetType() string { return m.Type }

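ModelPayload and MessagePayload exist so that JSON arriving over NATS can be decoded into something that satisfies RuntimeModel / RuntimeMessage, which is what handleClusterMessage now does. The same round trip in isolation, as a small sketch (the payload values are made up):

```go
package cluster

import (
	"encoding/json"
	"fmt"
)

func examplePayloadRoundTrip() error {
	// What a "route_message" payload might look like on the wire.
	raw := []byte(`{"targetActorId":"order-42","type":"OrderPlaced"}`)

	var msg MessagePayload
	if err := json.Unmarshal(raw, &msg); err != nil {
		return err
	}

	// *MessagePayload satisfies RuntimeMessage, so it can be handed to Runtime.SendMessage.
	var rm RuntimeMessage = &msg
	fmt.Println(rm.GetTargetActorID(), rm.GetType()) // order-42 OrderPlaced
	return nil
}
```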
event.go (45 changed lines)
@@ -28,6 +28,39 @@ func (e *VersionConflictError) Unwrap() error {
    return ErrVersionConflict
}

// ReplayError captures information about a malformed event encountered during replay.
// This allows callers to inspect and handle corrupted data without losing context.
type ReplayError struct {
    // SequenceNumber is the sequence number of the message in the stream (if available)
    SequenceNumber uint64
    // RawData is the raw bytes that could not be unmarshaled
    RawData []byte
    // Err is the underlying unmarshal error
    Err error
}

func (e *ReplayError) Error() string {
    return fmt.Sprintf("failed to unmarshal event at sequence %d: %v", e.SequenceNumber, e.Err)
}

func (e *ReplayError) Unwrap() error {
    return e.Err
}

// ReplayResult contains the results of replaying events, including any errors encountered.
// This allows callers to decide how to handle malformed events rather than silently skipping them.
type ReplayResult struct {
    // Events contains the successfully unmarshaled events
    Events []*Event
    // Errors contains information about any malformed events encountered
    Errors []ReplayError
}

// HasErrors returns true if any malformed events were encountered during replay
func (r *ReplayResult) HasErrors() bool {
    return len(r.Errors) > 0
}

// Event represents a domain event in the system
type Event struct {
    ID string `json:"id"`
@@ -174,6 +207,18 @@ type EventStore interface {
    GetLatestVersion(actorID string) (int64, error)
}

// EventStoreWithErrors extends EventStore with methods that report malformed events.
// Stores that may encounter corrupted data during replay (e.g., JetStream) should
// implement this interface to give callers visibility into data quality issues.
type EventStoreWithErrors interface {
    EventStore

    // GetEventsWithErrors retrieves events for an actor and reports any malformed
    // events encountered. This method allows callers to decide how to handle
    // corrupted data rather than silently skipping it.
    GetEventsWithErrors(actorID string, fromVersion int64) (*ReplayResult, error)
}

// SnapshotStore extends EventStore with snapshot capabilities
type SnapshotStore interface {
    EventStore

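The point of EventStoreWithErrors is that the caller chooses the policy for corrupted events instead of the store silently dropping them. A hedged sketch of a caller that logs and continues, written as if it sat alongside event.go (the store argument and actor ID are placeholders; any store implementing the interface, such as the JetStream store below, would fit):

```go
package aether

import "log"

func replayWithVisibility(store EventStoreWithErrors, actorID string) ([]*Event, error) {
	result, err := store.GetEventsWithErrors(actorID, 0)
	if err != nil {
		return nil, err
	}

	// Policy decision: log each malformed event and keep going.
	// A stricter caller could instead treat any entry in result.Errors as a hard failure.
	if result.HasErrors() {
		for _, re := range result.Errors {
			log.Printf("skipping malformed event (seq %d, %d bytes): %v",
				re.SequenceNumber, len(re.RawData), re.Err)
		}
	}
	return result.Events, nil
}
```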
event_test.go (127 changed lines)
@@ -1208,3 +1208,130 @@ func TestEvent_MetadataAllHelpersRoundTrip(t *testing.T) {
        t.Errorf("GetSpanID mismatch: got %q", decoded.GetSpanID())
    }
}

// Tests for ReplayError and ReplayResult types

func TestReplayError_Error(t *testing.T) {
    err := &ReplayError{
        SequenceNumber: 42,
        RawData:        []byte(`invalid json`),
        Err:            json.Unmarshal([]byte(`{`), &struct{}{}),
    }

    errMsg := err.Error()
    if !strings.Contains(errMsg, "42") {
        t.Errorf("expected error message to contain sequence number, got: %s", errMsg)
    }
    if !strings.Contains(errMsg, "unmarshal") || !strings.Contains(errMsg, "failed") {
        t.Errorf("expected error message to contain 'failed' and 'unmarshal', got: %s", errMsg)
    }
}

func TestReplayError_Unwrap(t *testing.T) {
    innerErr := json.Unmarshal([]byte(`{`), &struct{}{})
    err := &ReplayError{
        SequenceNumber: 1,
        RawData:        []byte(`{`),
        Err:            innerErr,
    }

    unwrapped := err.Unwrap()
    if unwrapped != innerErr {
        t.Errorf("expected Unwrap to return inner error")
    }
}

func TestReplayResult_HasErrors(t *testing.T) {
    tests := []struct {
        name     string
        result   *ReplayResult
        expected bool
    }{
        {
            name:     "no errors",
            result:   &ReplayResult{Events: []*Event{}, Errors: []ReplayError{}},
            expected: false,
        },
        {
            name:     "nil errors slice",
            result:   &ReplayResult{Events: []*Event{}, Errors: nil},
            expected: false,
        },
        {
            name: "has errors",
            result: &ReplayResult{
                Events: []*Event{},
                Errors: []ReplayError{
                    {SequenceNumber: 1, RawData: []byte(`bad`), Err: nil},
                },
            },
            expected: true,
        },
        {
            name: "has events and errors",
            result: &ReplayResult{
                Events: []*Event{{ID: "evt-1"}},
                Errors: []ReplayError{
                    {SequenceNumber: 2, RawData: []byte(`bad`), Err: nil},
                },
            },
            expected: true,
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            if got := tt.result.HasErrors(); got != tt.expected {
                t.Errorf("HasErrors() = %v, want %v", got, tt.expected)
            }
        })
    }
}

func TestReplayResult_EmptyResult(t *testing.T) {
    result := &ReplayResult{
        Events: []*Event{},
        Errors: []ReplayError{},
    }

    if result.HasErrors() {
        t.Error("expected HasErrors() to return false for empty result")
    }
    if len(result.Events) != 0 {
        t.Errorf("expected 0 events, got %d", len(result.Events))
    }
}

func TestReplayError_WithZeroSequence(t *testing.T) {
    err := &ReplayError{
        SequenceNumber: 0,
        RawData:        []byte(`corrupted`),
        Err:            json.Unmarshal([]byte(`not-json`), &struct{}{}),
    }

    errMsg := err.Error()
    if !strings.Contains(errMsg, "sequence 0") {
        t.Errorf("expected error message to contain 'sequence 0', got: %s", errMsg)
    }
}

func TestReplayError_WithLargeRawData(t *testing.T) {
    largeData := make([]byte, 1024*1024) // 1MB
    for i := range largeData {
        largeData[i] = 'x'
    }

    err := &ReplayError{
        SequenceNumber: 999,
        RawData:        largeData,
        Err:            json.Unmarshal(largeData, &struct{}{}),
    }

    // Should be able to create the error without issues
    if len(err.RawData) != 1024*1024 {
        t.Errorf("expected RawData to be preserved, got length %d", len(err.RawData))
    }

    // Error() should still work
    _ = err.Error()
}

store/config_test.go (new file, 46 changed lines)
@@ -0,0 +1,46 @@
package store

import (
    "testing"
    "time"
)

func TestDefaultJetStreamConfig(t *testing.T) {
    config := DefaultJetStreamConfig()

    if config.StreamRetention != DefaultStreamRetention {
        t.Errorf("expected StreamRetention=%v, got %v", DefaultStreamRetention, config.StreamRetention)
    }
    if config.ReplicaCount != DefaultReplicaCount {
        t.Errorf("expected ReplicaCount=%d, got %d", DefaultReplicaCount, config.ReplicaCount)
    }
}

func TestJetStreamConfigDefaults(t *testing.T) {
    t.Run("default stream retention is 1 year", func(t *testing.T) {
        expected := 365 * 24 * time.Hour
        if DefaultStreamRetention != expected {
            t.Errorf("expected DefaultStreamRetention=%v, got %v", expected, DefaultStreamRetention)
        }
    })

    t.Run("default replica count is 1", func(t *testing.T) {
        if DefaultReplicaCount != 1 {
            t.Errorf("expected DefaultReplicaCount=1, got %d", DefaultReplicaCount)
        }
    })
}

func TestJetStreamConfigCustomValues(t *testing.T) {
    config := JetStreamConfig{
        StreamRetention: 30 * 24 * time.Hour, // 30 days
        ReplicaCount:    3,
    }

    if config.StreamRetention != 30*24*time.Hour {
        t.Errorf("expected StreamRetention=30 days, got %v", config.StreamRetention)
    }
    if config.ReplicaCount != 3 {
        t.Errorf("expected ReplicaCount=3, got %d", config.ReplicaCount)
    }
}

@@ -11,29 +11,66 @@ import (
    "github.com/nats-io/nats.go"
)

// JetStreamEventStore implements EventStore using NATS JetStream for persistence
// Default configuration values for JetStream event store
const (
    DefaultStreamRetention = 365 * 24 * time.Hour // 1 year
    DefaultReplicaCount    = 1
)

// JetStreamConfig holds configuration options for JetStreamEventStore
type JetStreamConfig struct {
    // StreamRetention is how long to keep events (default: 1 year)
    StreamRetention time.Duration
    // ReplicaCount is the number of replicas for high availability (default: 1)
    ReplicaCount int
}

// DefaultJetStreamConfig returns the default configuration
func DefaultJetStreamConfig() JetStreamConfig {
    return JetStreamConfig{
        StreamRetention: DefaultStreamRetention,
        ReplicaCount:    DefaultReplicaCount,
    }
}

// JetStreamEventStore implements EventStore using NATS JetStream for persistence.
// It also implements EventStoreWithErrors to report malformed events during replay.
type JetStreamEventStore struct {
    js         nats.JetStreamContext
    streamName string
    config     JetStreamConfig
    mu         sync.Mutex       // Protects version checks during SaveEvent
    versions   map[string]int64 // actorID -> latest version cache
}

// NewJetStreamEventStore creates a new JetStream-based event store
// NewJetStreamEventStore creates a new JetStream-based event store with default configuration
func NewJetStreamEventStore(natsConn *nats.Conn, streamName string) (*JetStreamEventStore, error) {
    return NewJetStreamEventStoreWithConfig(natsConn, streamName, DefaultJetStreamConfig())
}

// NewJetStreamEventStoreWithConfig creates a new JetStream-based event store with custom configuration
func NewJetStreamEventStoreWithConfig(natsConn *nats.Conn, streamName string, config JetStreamConfig) (*JetStreamEventStore, error) {
    js, err := natsConn.JetStream()
    if err != nil {
        return nil, fmt.Errorf("failed to get JetStream context: %w", err)
    }

    // Apply defaults for zero values
    if config.StreamRetention == 0 {
        config.StreamRetention = DefaultStreamRetention
    }
    if config.ReplicaCount == 0 {
        config.ReplicaCount = DefaultReplicaCount
    }

    // Create or update the stream
    stream := &nats.StreamConfig{
        Name:      streamName,
        Subjects:  []string{fmt.Sprintf("%s.events.>", streamName), fmt.Sprintf("%s.snapshots.>", streamName)},
        Storage:   nats.FileStorage,
        Retention: nats.LimitsPolicy,
        MaxAge:    365 * 24 * time.Hour, // Keep events for 1 year
        Replicas:  1,                    // Can be increased for HA
        MaxAge:    config.StreamRetention,
        Replicas:  config.ReplicaCount,
    }

    _, err = js.AddStream(stream)
@@ -44,6 +81,7 @@ func NewJetStreamEventStore(natsConn *nats.Conn, streamName string) (*JetStreamE
    return &JetStreamEventStore{
        js:         js,
        streamName: streamName,
        config:     config,
        versions:   make(map[string]int64),
    }, nil
}
@@ -102,18 +140,18 @@ func (jes *JetStreamEventStore) getLatestVersionLocked(actorID string) (int64, e
        return version, nil
    }

    // Fetch from JetStream
    events, err := jes.getEventsInternal(actorID, 0)
    // Fetch from JetStream - use internal method that returns result
    result, err := jes.getEventsWithErrorsInternal(actorID, 0)
    if err != nil {
        return 0, err
    }

    if len(events) == 0 {
    if len(result.Events) == 0 {
        return 0, nil
    }

    latestVersion := int64(0)
    for _, event := range events {
    for _, event := range result.Events {
        if event.Version > latestVersion {
            latestVersion = event.Version
        }
@@ -125,13 +163,27 @@ func (jes *JetStreamEventStore) getLatestVersionLocked(actorID string) (int64, e
    return latestVersion, nil
}

// GetEvents retrieves all events for an actor since a version
// GetEvents retrieves all events for an actor since a version.
// Note: This method silently skips malformed events for backward compatibility.
// Use GetEventsWithErrors to receive information about malformed events.
func (jes *JetStreamEventStore) GetEvents(actorID string, fromVersion int64) ([]*aether.Event, error) {
    return jes.getEventsInternal(actorID, fromVersion)
    result, err := jes.getEventsWithErrorsInternal(actorID, fromVersion)
    if err != nil {
        return nil, err
    }
    return result.Events, nil
}

// getEventsInternal is the internal implementation of GetEvents
func (jes *JetStreamEventStore) getEventsInternal(actorID string, fromVersion int64) ([]*aether.Event, error) {
// GetEventsWithErrors retrieves events for an actor and reports any malformed
// events encountered. This method allows callers to decide how to handle
// corrupted data rather than silently skipping it.
func (jes *JetStreamEventStore) GetEventsWithErrors(actorID string, fromVersion int64) (*aether.ReplayResult, error) {
    return jes.getEventsWithErrorsInternal(actorID, fromVersion)
}

// getEventsWithErrorsInternal is the internal implementation that tracks both
// successfully parsed events and errors for malformed events.
func (jes *JetStreamEventStore) getEventsWithErrorsInternal(actorID string, fromVersion int64) (*aether.ReplayResult, error) {
    // Create subject filter for this actor
    subject := fmt.Sprintf("%s.events.%s.%s",
        jes.streamName,
@@ -145,7 +197,10 @@ func (jes *JetStreamEventStore) getEventsInternal(actorID string, fromVersion in
    }
    defer consumer.Unsubscribe()

    var events []*aether.Event
    result := &aether.ReplayResult{
        Events: make([]*aether.Event, 0),
        Errors: make([]aether.ReplayError, 0),
    }

    // Fetch messages in batches
    for {
@@ -160,12 +215,24 @@ func (jes *JetStreamEventStore) getEventsInternal(actorID string, fromVersion in
        for _, msg := range msgs {
            var event aether.Event
            if err := json.Unmarshal(msg.Data, &event); err != nil {
                continue // Skip malformed events
                // Record the error with context instead of silently skipping
                metadata, _ := msg.Metadata()
                seqNum := uint64(0)
                if metadata != nil {
                    seqNum = metadata.Sequence.Stream
                }
                result.Errors = append(result.Errors, aether.ReplayError{
                    SequenceNumber: seqNum,
                    RawData:        msg.Data,
                    Err:            err,
                })
                msg.Ack() // Still ack to prevent redelivery
                continue
            }

            // Filter by version
            if event.Version > fromVersion {
                events = append(events, &event)
                result.Events = append(result.Events, &event)
            }

            msg.Ack()
@@ -176,7 +243,7 @@ func (jes *JetStreamEventStore) getEventsInternal(actorID string, fromVersion in
        }
    }

    return events, nil
    return result, nil
}

// GetLatestVersion returns the latest version for an actor
@@ -279,3 +346,6 @@ func sanitizeSubject(s string) string {
    s = strings.ReplaceAll(s, ">", "_")
    return s
}

// Compile-time check that JetStreamEventStore implements EventStoreWithErrors
var _ aether.EventStoreWithErrors = (*JetStreamEventStore)(nil)

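Putting the two halves of this change together, a consumer can now size the stream explicitly and still get visibility into malformed events. A hedged sketch of the intended call pattern; the NATS URL, stream name, actor ID, and module import paths are placeholders, while the two constructors and GetEventsWithErrors come from the diff above:

```go
package main

import (
	"log"
	"time"

	"github.com/nats-io/nats.go"

	// Placeholder import paths; substitute the real module path.
	aether "example.com/aether"
	"example.com/aether/store"
)

func main() {
	nc, err := nats.Connect("nats://127.0.0.1:4222")
	if err != nil {
		log.Fatal(err)
	}
	defer nc.Close()

	// 30-day retention and 3 replicas; zero values would fall back to the defaults.
	es, err := store.NewJetStreamEventStoreWithConfig(nc, "AETHER", store.JetStreamConfig{
		StreamRetention: 30 * 24 * time.Hour,
		ReplicaCount:    3,
	})
	if err != nil {
		log.Fatal(err)
	}

	// Replay with error reporting instead of the silently-skipping GetEvents.
	var _ aether.EventStoreWithErrors = es // the store satisfies the extended interface
	result, err := es.GetEventsWithErrors("order-42", 0)
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("replayed %d events, %d malformed", len(result.Events), len(result.Errors))
}
```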