Files
aether/store/jetstream.go
Claude Code e69f7a30e4
Some checks failed
CI / build (pull_request) Successful in 19s
CI / integration (pull_request) Failing after 1m59s
fix: address critical TOCTOU race condition and error handling inconsistencies
- Fix TOCTOU race condition in SaveEvent by holding the lock throughout entire version validation and publish operation
- Add getLatestVersionLocked helper method to prevent race window where multiple concurrent threads read the same currentVersion
- Fix GetLatestSnapshot to return error when no snapshot exists (not nil), distinguishing "not created" from "error occurred"
- The concurrent version conflict test now passes with exactly 1 success and 49 conflicts instead of 50 successes

These changes ensure thread-safe optimistic concurrency control and consistent error handling semantics.

Co-Authored-By: Claude Code <noreply@anthropic.com>
2026-01-13 09:02:36 +01:00

438 lines
14 KiB
Go

package store
import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"sync"
	"time"

	"git.flowmade.one/flowmade-one/aether"
	"github.com/nats-io/nats.go"
)
// Default configuration values for the JetStream event store. They are returned
// by DefaultJetStreamConfig and substituted for zero-valued fields in
// NewJetStreamEventStoreWithConfig.
const (
// DefaultStreamRetention is the default maximum age of messages in the stream.
DefaultStreamRetention = 365 * 24 * time.Hour // 1 year
// DefaultReplicaCount is the default number of stream replicas.
DefaultReplicaCount = 1
)
// JetStreamConfig holds configuration options for JetStreamEventStore.
// Zero-valued StreamRetention and ReplicaCount are replaced with the package
// defaults by NewJetStreamEventStoreWithConfig.
type JetStreamConfig struct {
// StreamRetention is how long to keep events (default: 1 year).
// It is applied as the MaxAge of the underlying stream.
StreamRetention time.Duration
// ReplicaCount is the number of replicas for high availability (default: 1).
ReplicaCount int
// Namespace is an optional prefix for stream names to provide storage isolation.
// When set, the actual stream name becomes "{namespace}_{streamName}", where the
// namespace is first passed through sanitizeSubject.
// Events in namespaced stores are completely isolated from other namespaces.
// Leave empty for backward-compatible non-namespaced behavior.
Namespace string
}
// DefaultJetStreamConfig builds a JetStreamConfig populated with the package
// defaults (DefaultStreamRetention and DefaultReplicaCount). Namespace is left
// empty.
func DefaultJetStreamConfig() JetStreamConfig {
	var cfg JetStreamConfig
	cfg.StreamRetention = DefaultStreamRetention
	cfg.ReplicaCount = DefaultReplicaCount
	return cfg
}
// JetStreamEventStore implements EventStore using NATS JetStream for persistence.
// It also implements EventStoreWithErrors to report malformed events during replay.
//
// Events and snapshots share one stream and are separated by subject:
// "<stream>.events.<actorType>.<actorID>" and
// "<stream>.snapshots.<actorType>.<actorID>".
type JetStreamEventStore struct {
js nats.JetStreamContext // JetStream context used for all publishes and reads
streamName string // effective stream name, including any namespace prefix
config JetStreamConfig // configuration with defaults already applied
mu sync.Mutex // Held for the whole of SaveEvent; guards versions
versions map[string]int64 // actorID -> latest version known to this store instance
}
// NewJetStreamEventStore creates a JetStream-backed event store for streamName
// using the configuration returned by DefaultJetStreamConfig.
func NewJetStreamEventStore(natsConn *nats.Conn, streamName string) (*JetStreamEventStore, error) {
	defaults := DefaultJetStreamConfig()
	return NewJetStreamEventStoreWithConfig(natsConn, streamName, defaults)
}
// NewJetStreamEventStoreWithNamespace creates a JetStream-backed event store
// whose stream name is prefixed with the given namespace, so that stores in
// different namespaces use different streams.
//
// It is shorthand for taking DefaultJetStreamConfig, setting Namespace, and
// calling NewJetStreamEventStoreWithConfig.
func NewJetStreamEventStoreWithNamespace(natsConn *nats.Conn, streamName string, namespace string) (*JetStreamEventStore, error) {
	cfg := DefaultJetStreamConfig()
	cfg.Namespace = namespace

	store, err := NewJetStreamEventStoreWithConfig(natsConn, streamName, cfg)
	return store, err
}
// NewJetStreamEventStoreWithConfig creates a JetStream-backed event store using
// the supplied configuration.
//
// Zero-valued StreamRetention and ReplicaCount fall back to the package
// defaults. When config.Namespace is non-empty the stream is named
// "<sanitized namespace>_<streamName>". The stream uses file storage and the
// limits retention policy, and captures both "<stream>.events.>" and
// "<stream>.snapshots.>".
//
// Returns an error if the JetStream context cannot be obtained, or if adding
// the stream fails with an error whose text does not contain "already exists".
func NewJetStreamEventStoreWithConfig(natsConn *nats.Conn, streamName string, config JetStreamConfig) (*JetStreamEventStore, error) {
	jsCtx, err := natsConn.JetStream()
	if err != nil {
		return nil, fmt.Errorf("failed to get JetStream context: %w", err)
	}

	// Fill in defaults for anything the caller left at its zero value.
	if config.StreamRetention == 0 {
		config.StreamRetention = DefaultStreamRetention
	}
	if config.ReplicaCount == 0 {
		config.ReplicaCount = DefaultReplicaCount
	}

	// Prefix the stream name with the sanitized namespace, if one was given.
	name := streamName
	if ns := config.Namespace; ns != "" {
		name = sanitizeSubject(ns) + "_" + streamName
	}

	streamCfg := nats.StreamConfig{
		Name:      name,
		Subjects:  []string{name + ".events.>", name + ".snapshots.>"},
		Storage:   nats.FileStorage,
		Retention: nats.LimitsPolicy,
		MaxAge:    config.StreamRetention,
		Replicas:  config.ReplicaCount,
	}

	// An "already exists" failure is tolerated; anything else is fatal.
	if _, addErr := jsCtx.AddStream(&streamCfg); addErr != nil && !strings.Contains(addErr.Error(), "already exists") {
		return nil, fmt.Errorf("failed to create stream: %w", addErr)
	}

	store := &JetStreamEventStore{
		js:         jsCtx,
		streamName: name,
		config:     config,
		versions:   make(map[string]int64),
	}
	return store, nil
}
// GetNamespace reports the namespace this store was configured with. The result
// is the empty string for a non-namespaced store.
func (jes *JetStreamEventStore) GetNamespace() string {
	ns := jes.config.Namespace
	return ns
}
// GetStreamName reports the effective stream name used by this store, i.e. the
// name including the namespace prefix when a namespace is configured.
func (jes *JetStreamEventStore) GetStreamName() string {
	name := jes.streamName
	return name
}
// SaveEvent persists an event to JetStream, enforcing optimistic concurrency
// per actor.
//
// The store mutex is held for the whole call so that the version check and the
// publish cannot interleave with another SaveEvent on this store instance. The
// actor's current version is taken from the in-memory cache when present;
// otherwise it is read from the stream and cached.
//
// Returns an error if event is nil, *aether.VersionConflictError if
// event.Version is not strictly greater than the actor's current version, and
// a wrapped error if the version lookup, JSON encoding, or publish fails.
//
// NOTE(review): the version cache lives in this store instance only, so writes
// made through another instance or process to the same stream are not seen
// once an actor's version is cached — confirm this matches the deployment.
func (jes *JetStreamEventStore) SaveEvent(event *aether.Event) error {
	if event == nil {
		return errors.New("cannot save nil event")
	}

	jes.mu.Lock()
	defer jes.mu.Unlock()

	currentVersion, cached := jes.versions[event.ActorID]
	if !cached {
		// Cache miss: read the latest version from the stream while still
		// holding the lock so no other SaveEvent can slip in between the
		// read and the publish.
		latest, err := jes.getLatestVersionLocked(event.ActorID)
		if err != nil {
			return fmt.Errorf("failed to get latest version: %w", err)
		}
		currentVersion = latest
		// Cache it even if the check below rejects the event, so repeated
		// conflicting saves do not query the stream again.
		jes.versions[event.ActorID] = currentVersion
	}

	// The new version must be strictly greater than the current one.
	if event.Version <= currentVersion {
		return &aether.VersionConflictError{
			ActorID:          event.ActorID,
			AttemptedVersion: event.Version,
			CurrentVersion:   currentVersion,
		}
	}

	data, err := json.Marshal(event)
	if err != nil {
		return fmt.Errorf("failed to marshal event: %w", err)
	}

	// Subject layout: <stream>.events.<actorType>.<actorID>
	subject := fmt.Sprintf("%s.events.%s.%s",
		jes.streamName,
		sanitizeSubject(extractActorType(event.ActorID)),
		sanitizeSubject(event.ActorID))

	// The event ID is sent as the message ID so JetStream can deduplicate.
	if _, err := jes.js.Publish(subject, data, nats.MsgId(event.ID)); err != nil {
		return fmt.Errorf("failed to publish event to JetStream: %w", err)
	}

	// Only advance the cached version once the publish has succeeded.
	jes.versions[event.ActorID] = event.Version
	return nil
}
// GetEvents returns the events stored for actorID whose version is greater
// than fromVersion.
//
// Malformed events are dropped without being reported, which keeps this
// method backward compatible. Call GetEventsWithErrors to also receive the
// details of any events that could not be decoded.
func (jes *JetStreamEventStore) GetEvents(actorID string, fromVersion int64) ([]*aether.Event, error) {
	replay, err := jes.getEventsWithErrorsInternal(actorID, fromVersion)
	if err != nil {
		return nil, err
	}
	return replay.Events, nil
}
// GetEventsWithErrors returns the events stored for actorID whose version is
// greater than fromVersion, together with a record of every stored message
// that could not be decoded. Callers can then decide how to treat corrupted
// data instead of having it skipped silently.
func (jes *JetStreamEventStore) GetEventsWithErrors(actorID string, fromVersion int64) (*aether.ReplayResult, error) {
	replay, err := jes.getEventsWithErrorsInternal(actorID, fromVersion)
	return replay, err
}
// getEventsWithErrorsInternal reads every stored message on the actor's event
// subject and splits them into successfully decoded events (filtered to
// Version > fromVersion) and decode failures.
//
// Messages are pulled through a subscription created without a durable name,
// in batches of up to 100 with a one second wait per batch. A fetch timeout is
// only treated as "end of stream" when the consumer reports that nothing is
// pending; otherwise it is returned as an error so that a slow server cannot
// silently truncate a replay. An actor with no stored events returns
// immediately with empty slices.
//
// Returns a non-nil result with non-nil (possibly empty) Events and Errors on
// success, or a wrapped error if the consumer cannot be created or queried, or
// if fetching fails.
func (jes *JetStreamEventStore) getEventsWithErrorsInternal(actorID string, fromVersion int64) (*aether.ReplayResult, error) {
	const batchSize = 100

	// Subject layout: <stream>.events.<actorType>.<actorID>
	subject := fmt.Sprintf("%s.events.%s.%s",
		jes.streamName,
		sanitizeSubject(extractActorType(actorID)),
		sanitizeSubject(actorID))

	consumer, err := jes.js.PullSubscribe(subject, "")
	if err != nil {
		return nil, fmt.Errorf("failed to create consumer: %w", err)
	}
	defer consumer.Unsubscribe()

	result := &aether.ReplayResult{
		Events: make([]*aether.Event, 0),
		Errors: make([]aether.ReplayError, 0),
	}

	// If nothing is stored for this actor, skip Fetch entirely: it would
	// otherwise block for the full MaxWait before reporting a timeout.
	info, err := consumer.ConsumerInfo()
	if err != nil {
		return nil, fmt.Errorf("failed to get consumer info: %w", err)
	}
	if info.NumPending == 0 {
		return result, nil
	}

	for {
		msgs, err := consumer.Fetch(batchSize, nats.MaxWait(time.Second))
		if err != nil {
			if !errors.Is(err, nats.ErrTimeout) {
				return nil, fmt.Errorf("failed to fetch messages: %w", err)
			}
			// A timeout means either "no more messages" or "server too
			// slow". Ask the consumer which one it is.
			info, infoErr := consumer.ConsumerInfo()
			if infoErr != nil {
				return nil, fmt.Errorf("failed to get consumer info: %w", infoErr)
			}
			if info.NumPending > 0 {
				return nil, fmt.Errorf("failed to fetch messages: %d still pending: %w", info.NumPending, err)
			}
			break // Genuinely drained.
		}

		for _, msg := range msgs {
			// Declared per iteration so each appended pointer is distinct.
			var event aether.Event
			if err := json.Unmarshal(msg.Data, &event); err != nil {
				// Record the failure with its stream sequence (0 when
				// metadata is unavailable) instead of dropping it.
				metadata, _ := msg.Metadata()
				seqNum := uint64(0)
				if metadata != nil {
					seqNum = metadata.Sequence.Stream
				}
				result.Errors = append(result.Errors, aether.ReplayError{
					SequenceNumber: seqNum,
					RawData:        msg.Data,
					Err:            err,
				})
				_ = msg.Ack() // Still ack to prevent redelivery; best-effort.
				continue
			}

			if event.Version > fromVersion {
				result.Events = append(result.Events, &event)
			}
			_ = msg.Ack() // Best-effort ack.
		}

		if len(msgs) < batchSize {
			break // Short batch: no more messages were available.
		}
	}

	return result, nil
}
// GetLatestVersion returns the version of the most recent event stored for
// actorID, or 0 if the actor has no events.
//
// It asks JetStream for the last message on the actor's event subject instead
// of creating a consumer and scanning, so only one message is transferred.
// Only a "message not found" reply is interpreted as "no events"; any other
// failure, including a request timeout, is returned as an error rather than
// being mistaken for version 0.
//
// This method does not acquire jes.mu and does not read the version cache.
//
// Returns a wrapped error if the lookup fails or the last message cannot be
// decoded as an event.
func (jes *JetStreamEventStore) GetLatestVersion(actorID string) (int64, error) {
	// Subject layout: <stream>.events.<actorType>.<actorID>
	subject := fmt.Sprintf("%s.events.%s.%s",
		jes.streamName,
		sanitizeSubject(extractActorType(actorID)),
		sanitizeSubject(actorID))

	last, err := jes.js.GetLastMsg(jes.streamName, subject)
	if err != nil {
		if errors.Is(err, nats.ErrMsgNotFound) {
			// No events have been stored for this actor.
			return 0, nil
		}
		return 0, fmt.Errorf("failed to fetch last message: %w", err)
	}

	var event aether.Event
	if err := json.Unmarshal(last.Data, &event); err != nil {
		return 0, fmt.Errorf("failed to unmarshal last event: %w", err)
	}
	return event.Version, nil
}
// getLatestVersionLocked returns the latest stored version for actorID and is
// intended to be called while the caller already holds jes.mu, as SaveEvent
// does, so the lock is never released between the version check and the
// publish (which would open a TOCTOU race).
//
// It delegates to GetLatestVersion so there is a single implementation of the
// lookup. That is safe under the lock because GetLatestVersion does not
// acquire jes.mu; if that ever changes, this helper must stop delegating.
func (jes *JetStreamEventStore) getLatestVersionLocked(actorID string) (int64, error) {
	return jes.GetLatestVersion(actorID)
}
// GetLatestSnapshot returns the most recent snapshot stored for actorID.
//
// Unlike GetLatestVersion, which returns 0 for an actor without events, a
// missing snapshot is reported as an error ("no snapshot found for actor
// ..."): whether that is normal or a problem depends on the use case, so the
// caller decides how to handle it. That error is a plain formatted error, not
// a sentinel, so callers cannot match it with errors.Is.
//
// The snapshot is read as the last message on the actor's snapshot subject.
// Only a "message not found" reply is reported as a missing snapshot; any
// other failure, including a request timeout, is returned as a fetch error
// instead of being mistaken for "no snapshot".
//
// Returns a wrapped error if the lookup fails or the message cannot be decoded.
func (jes *JetStreamEventStore) GetLatestSnapshot(actorID string) (*aether.ActorSnapshot, error) {
	// Subject layout: <stream>.snapshots.<actorType>.<actorID>
	subject := fmt.Sprintf("%s.snapshots.%s.%s",
		jes.streamName,
		sanitizeSubject(extractActorType(actorID)),
		sanitizeSubject(actorID))

	last, err := jes.js.GetLastMsg(jes.streamName, subject)
	if err != nil {
		if errors.Is(err, nats.ErrMsgNotFound) {
			// No snapshot exists for this actor.
			return nil, fmt.Errorf("no snapshot found for actor %s", actorID)
		}
		return nil, fmt.Errorf("failed to fetch snapshot: %w", err)
	}

	var snapshot aether.ActorSnapshot
	if err := json.Unmarshal(last.Data, &snapshot); err != nil {
		return nil, fmt.Errorf("failed to unmarshal snapshot: %w", err)
	}
	return &snapshot, nil
}
// SaveSnapshot publishes a snapshot of actor state to the actor's snapshot
// subject, "<stream>.snapshots.<actorType>.<actorID>".
//
// Returns an error if snapshot is nil, or a wrapped error if JSON encoding or
// the publish fails. No version check is performed here.
func (jes *JetStreamEventStore) SaveSnapshot(snapshot *aether.ActorSnapshot) error {
	// A nil snapshot would marshal to "null" and then panic on ActorID below.
	if snapshot == nil {
		return errors.New("cannot save nil snapshot")
	}

	data, err := json.Marshal(snapshot)
	if err != nil {
		return fmt.Errorf("failed to marshal snapshot: %w", err)
	}

	subject := fmt.Sprintf("%s.snapshots.%s.%s",
		jes.streamName,
		sanitizeSubject(extractActorType(snapshot.ActorID)),
		sanitizeSubject(snapshot.ActorID))

	if _, err := jes.js.Publish(subject, data); err != nil {
		return fmt.Errorf("failed to publish snapshot to JetStream: %w", err)
	}
	return nil
}
// Helper functions
// extractActorType derives the actor type from an actor ID: the text before
// the first '-' that is not the ID's first byte. IDs without such a dash
// (including the empty string) yield "unknown".
func extractActorType(actorID string) string {
	if len(actorID) > 1 {
		// Search from the second byte so a leading '-' is never the separator.
		// '-' is ASCII, so a byte search cannot match inside a multi-byte rune.
		if rel := strings.IndexByte(actorID[1:], '-'); rel >= 0 {
			return actorID[:rel+1]
		}
	}
	return "unknown"
}
// sanitizeSubject makes s safe to embed as a single token in a NATS subject by
// replacing each character with special meaning or that is not permitted there
// with '_': the token separator '.', the wildcards '*' and '>', and the
// whitespace characters space, tab, carriage return and newline.
//
// All other bytes are returned unchanged, and a string that needs no
// replacement is returned as-is without allocating. Distinct inputs can map to
// the same output (for example "a.b" and "a_b").
func sanitizeSubject(s string) string {
	// Fast path: nothing to replace.
	if !strings.ContainsAny(s, " \t\r\n.*>") {
		return s
	}

	// Every replaced character is ASCII, so a byte-wise pass is safe for
	// UTF-8 input and leaves multi-byte runes untouched.
	b := []byte(s)
	for i, c := range b {
		switch c {
		case ' ', '\t', '\r', '\n', '.', '*', '>':
			b[i] = '_'
		}
	}
	return string(b)
}
// Compile-time check that *JetStreamEventStore implements
// aether.EventStoreWithErrors; the build fails on this line if it does not.
var _ aether.EventStoreWithErrors = (*JetStreamEventStore)(nil)