package store

import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"sync"
	"time"

	"git.flowmade.one/flowmade-one/aether"
	"github.com/nats-io/nats.go"
)

// Default configuration values for JetStream event store.
const (
	DefaultStreamRetention = 365 * 24 * time.Hour // 1 year
	DefaultReplicaCount    = 1
)

// JetStreamConfig holds configuration options for JetStreamEventStore.
type JetStreamConfig struct {
	// StreamRetention is how long to keep events (default: 1 year).
	StreamRetention time.Duration
	// ReplicaCount is the number of replicas for high availability (default: 1).
	ReplicaCount int
}

// DefaultJetStreamConfig returns the default configuration.
func DefaultJetStreamConfig() JetStreamConfig {
	return JetStreamConfig{
		StreamRetention: DefaultStreamRetention,
		ReplicaCount:    DefaultReplicaCount,
	}
}

// JetStreamEventStore implements EventStore using NATS JetStream for persistence.
type JetStreamEventStore struct {
	js         nats.JetStreamContext
	streamName string
	config     JetStreamConfig

	mu       sync.Mutex       // serializes the version check + publish in SaveEvent
	versions map[string]int64 // actorID -> latest known version (per-instance cache)
}

// NewJetStreamEventStore creates a new JetStream-based event store with default configuration.
func NewJetStreamEventStore(natsConn *nats.Conn, streamName string) (*JetStreamEventStore, error) {
	return NewJetStreamEventStoreWithConfig(natsConn, streamName, DefaultJetStreamConfig())
}

// NewJetStreamEventStoreWithConfig creates a new JetStream-based event store
// with custom configuration.
//
// A StreamRetention or ReplicaCount that is zero or negative is replaced with
// the package default. The stream streamName is created with file storage and
// the subjects "<streamName>.events.>" and "<streamName>.snapshots.>". If a
// stream with that name already exists with a different configuration, it is
// updated to match; an error is returned if the update is rejected.
func NewJetStreamEventStoreWithConfig(natsConn *nats.Conn, streamName string, config JetStreamConfig) (*JetStreamEventStore, error) {
	js, err := natsConn.JetStream()
	if err != nil {
		return nil, fmt.Errorf("failed to get JetStream context: %w", err)
	}

	// Apply defaults for unset or invalid values; the server rejects negative
	// MaxAge/Replicas anyway.
	if config.StreamRetention <= 0 {
		config.StreamRetention = DefaultStreamRetention
	}
	if config.ReplicaCount <= 0 {
		config.ReplicaCount = DefaultReplicaCount
	}

	streamCfg := &nats.StreamConfig{
		Name: streamName,
		Subjects: []string{
			fmt.Sprintf("%s.events.>", streamName),
			fmt.Sprintf("%s.snapshots.>", streamName),
		},
		Storage:   nats.FileStorage,
		Retention: nats.LimitsPolicy,
		MaxAge:    config.StreamRetention,
		Replicas:  config.ReplicaCount,
	}

	// Create or update the stream.
	if _, err := js.AddStream(streamCfg); err != nil {
		switch {
		case errors.Is(err, nats.ErrStreamNameAlreadyInUse):
			// The stream exists but its configuration differs from streamCfg;
			// reconcile it rather than failing.
			if _, err := js.UpdateStream(streamCfg); err != nil {
				return nil, fmt.Errorf("failed to update stream: %w", err)
			}
		case strings.Contains(err.Error(), "already exists"):
			// Preserved from the original implementation: treat this wording
			// as "stream already present".
		default:
			return nil, fmt.Errorf("failed to create stream: %w", err)
		}
	}

	return &JetStreamEventStore{
		js:         js,
		streamName: streamName,
		config:     config,
		versions:   make(map[string]int64),
	}, nil
}

// SaveEvent persists an event to JetStream.
//
// It returns *aether.VersionConflictError if event.Version is not strictly
// greater than the current latest version for event.ActorID. The check and the
// publish run under jes.mu, so they are serialized across calls on this store
// instance. The version cache is per instance, however, so once an actor's
// version is cached, writes made through other instances or processes are not
// seen by the check.
func (jes *JetStreamEventStore) SaveEvent(event *aether.Event) error {
	if event == nil {
		return errors.New("cannot save nil event")
	}

	jes.mu.Lock()
	defer jes.mu.Unlock()

	// Get current latest version for this actor.
	currentVersion, err := jes.getLatestVersionLocked(event.ActorID)
	if err != nil {
		return fmt.Errorf("failed to get latest version: %w", err)
	}

	// Validate version is strictly greater than current.
	if event.Version <= currentVersion {
		return &aether.VersionConflictError{
			ActorID:          event.ActorID,
			AttemptedVersion: event.Version,
			CurrentVersion:   currentVersion,
		}
	}

	data, err := json.Marshal(event)
	if err != nil {
		return fmt.Errorf("failed to marshal event: %w", err)
	}

	// Subject layout: <stream>.events.<actorType>.<actorID>
	subject := fmt.Sprintf("%s.events.%s.%s", jes.streamName,
		sanitizeSubject(extractActorType(event.ActorID)), sanitizeSubject(event.ActorID))

	// Publish with the event ID as the JetStream message ID so the server can
	// deduplicate a retried publish of the same event.
	if _, err := jes.js.Publish(subject, data, nats.MsgId(event.ID)); err != nil {
		return fmt.Errorf("failed to publish event to JetStream: %w", err)
	}

	// Advance the cache only after Publish has returned successfully.
	jes.versions[event.ActorID] = event.Version
	return nil
}

// getLatestVersionLocked returns the latest version for an actor, consulting
// the in-process version cache before reading from JetStream.
// Caller must hold jes.mu.
func (jes *JetStreamEventStore) getLatestVersionLocked(actorID string) (int64, error) { // Check cache first if version, ok := jes.versions[actorID]; ok { return version, nil } // Fetch from JetStream events, err := jes.getEventsInternal(actorID, 0) if err != nil { return 0, err } if len(events) == 0 { return 0, nil } latestVersion := int64(0) for _, event := range events { if event.Version > latestVersion { latestVersion = event.Version } } // Update cache jes.versions[actorID] = latestVersion return latestVersion, nil } // GetEvents retrieves all events for an actor since a version func (jes *JetStreamEventStore) GetEvents(actorID string, fromVersion int64) ([]*aether.Event, error) { return jes.getEventsInternal(actorID, fromVersion) } // getEventsInternal is the internal implementation of GetEvents func (jes *JetStreamEventStore) getEventsInternal(actorID string, fromVersion int64) ([]*aether.Event, error) { // Create subject filter for this actor subject := fmt.Sprintf("%s.events.%s.%s", jes.streamName, sanitizeSubject(extractActorType(actorID)), sanitizeSubject(actorID)) // Create consumer to read events consumer, err := jes.js.PullSubscribe(subject, "") if err != nil { return nil, fmt.Errorf("failed to create consumer: %w", err) } defer consumer.Unsubscribe() var events []*aether.Event // Fetch messages in batches for { msgs, err := consumer.Fetch(100, nats.MaxWait(time.Second)) if err != nil { if err == nats.ErrTimeout { break // No more messages } return nil, fmt.Errorf("failed to fetch messages: %w", err) } for _, msg := range msgs { var event aether.Event if err := json.Unmarshal(msg.Data, &event); err != nil { continue // Skip malformed events } // Filter by version if event.Version > fromVersion { events = append(events, &event) } msg.Ack() } if len(msgs) < 100 { break // No more messages } } return events, nil } // GetLatestVersion returns the latest version for an actor func (jes *JetStreamEventStore) GetLatestVersion(actorID string) (int64, error) { 
events, err := jes.GetEvents(actorID, 0) if err != nil { return 0, err } if len(events) == 0 { return 0, nil } latestVersion := int64(0) for _, event := range events { if event.Version > latestVersion { latestVersion = event.Version } } return latestVersion, nil } // GetLatestSnapshot gets the most recent snapshot for an actor func (jes *JetStreamEventStore) GetLatestSnapshot(actorID string) (*aether.ActorSnapshot, error) { // Create subject for snapshots subject := fmt.Sprintf("%s.snapshots.%s.%s", jes.streamName, sanitizeSubject(extractActorType(actorID)), sanitizeSubject(actorID)) // Try to get the latest snapshot consumer, err := jes.js.PullSubscribe(subject, "", nats.DeliverLast()) if err != nil { return nil, fmt.Errorf("failed to create snapshot consumer: %w", err) } defer consumer.Unsubscribe() msgs, err := consumer.Fetch(1, nats.MaxWait(time.Second)) if err != nil { if err == nats.ErrTimeout { return nil, fmt.Errorf("no snapshot found for actor %s", actorID) } return nil, fmt.Errorf("failed to fetch snapshot: %w", err) } if len(msgs) == 0 { return nil, fmt.Errorf("no snapshot found for actor %s", actorID) } var snapshot aether.ActorSnapshot if err := json.Unmarshal(msgs[0].Data, &snapshot); err != nil { return nil, fmt.Errorf("failed to unmarshal snapshot: %w", err) } msgs[0].Ack() return &snapshot, nil } // SaveSnapshot saves a snapshot of actor state func (jes *JetStreamEventStore) SaveSnapshot(snapshot *aether.ActorSnapshot) error { // Serialize snapshot to JSON data, err := json.Marshal(snapshot) if err != nil { return fmt.Errorf("failed to marshal snapshot: %w", err) } // Create subject for snapshots subject := fmt.Sprintf("%s.snapshots.%s.%s", jes.streamName, sanitizeSubject(extractActorType(snapshot.ActorID)), sanitizeSubject(snapshot.ActorID)) // Publish snapshot _, err = jes.js.Publish(subject, data) if err != nil { return fmt.Errorf("failed to publish snapshot to JetStream: %w", err) } return nil } // Helper functions // extractActorType extracts 
the actor type from an actor ID func extractActorType(actorID string) string { for i, c := range actorID { if c == '-' && i > 0 { return actorID[:i] } } return "unknown" } // sanitizeSubject sanitizes a string for use in NATS subjects func sanitizeSubject(s string) string { s = strings.ReplaceAll(s, " ", "_") s = strings.ReplaceAll(s, ".", "_") s = strings.ReplaceAll(s, "*", "_") s = strings.ReplaceAll(s, ">", "_") return s }