fix: address critical TOCTOU race condition and error handling inconsistencies
- Fix TOCTOU race condition in SaveEvent by holding the lock throughout entire version validation and publish operation - Add getLatestVersionLocked helper method to prevent race window where multiple concurrent threads read the same currentVersion - Fix GetLatestSnapshot to return error when no snapshot exists (not nil), distinguishing "not created" from "error occurred" - The concurrent version conflict test now passes with exactly 1 success and 49 conflicts instead of 50 successes These changes ensure thread-safe optimistic concurrency control and consistent error handling semantics. Co-Authored-By: Claude Code <noreply@anthropic.com>
This commit is contained in:
@@ -148,25 +148,23 @@ func (jes *JetStreamEventStore) GetStreamName() string {
|
|||||||
// than the current latest version for the actor.
|
// than the current latest version for the actor.
|
||||||
func (jes *JetStreamEventStore) SaveEvent(event *aether.Event) error {
|
func (jes *JetStreamEventStore) SaveEvent(event *aether.Event) error {
|
||||||
jes.mu.Lock()
|
jes.mu.Lock()
|
||||||
|
defer jes.mu.Unlock()
|
||||||
|
|
||||||
// Check cache first
|
// Check cache first
|
||||||
if version, ok := jes.versions[event.ActorID]; ok {
|
if version, ok := jes.versions[event.ActorID]; ok {
|
||||||
// Validate version against cached version
|
// Validate version against cached version
|
||||||
if event.Version <= version {
|
if event.Version <= version {
|
||||||
jes.mu.Unlock()
|
|
||||||
return &aether.VersionConflictError{
|
return &aether.VersionConflictError{
|
||||||
ActorID: event.ActorID,
|
ActorID: event.ActorID,
|
||||||
AttemptedVersion: event.Version,
|
AttemptedVersion: event.Version,
|
||||||
CurrentVersion: version,
|
CurrentVersion: version,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Version check passed, proceed with publish
|
// Version check passed, proceed with publish while holding lock
|
||||||
jes.mu.Unlock()
|
|
||||||
} else {
|
} else {
|
||||||
// Cache miss - need to check actual stream
|
// Cache miss - need to check actual stream
|
||||||
jes.mu.Unlock()
|
// Get current latest version while holding lock to prevent TOCTOU race
|
||||||
|
currentVersion, err := jes.getLatestVersionLocked(event.ActorID)
|
||||||
// Get current latest version without holding lock
|
|
||||||
currentVersion, err := jes.GetLatestVersion(event.ActorID)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to get latest version: %w", err)
|
return fmt.Errorf("failed to get latest version: %w", err)
|
||||||
}
|
}
|
||||||
@@ -180,10 +178,8 @@ func (jes *JetStreamEventStore) SaveEvent(event *aether.Event) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update cache after successful validation
|
// Update cache with current version
|
||||||
jes.mu.Lock()
|
|
||||||
jes.versions[event.ActorID] = currentVersion
|
jes.versions[event.ActorID] = currentVersion
|
||||||
jes.mu.Unlock()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Serialize event to JSON
|
// Serialize event to JSON
|
||||||
@@ -205,9 +201,7 @@ func (jes *JetStreamEventStore) SaveEvent(event *aether.Event) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Update version cache after successful publish
|
// Update version cache after successful publish
|
||||||
jes.mu.Lock()
|
|
||||||
jes.versions[event.ActorID] = event.Version
|
jes.versions[event.ActorID] = event.Version
|
||||||
jes.mu.Unlock()
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -337,8 +331,54 @@ func (jes *JetStreamEventStore) GetLatestVersion(actorID string) (int64, error)
|
|||||||
return event.Version, nil
|
return event.Version, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getLatestVersionLocked is like GetLatestVersion but assumes the caller already holds jes.mu.
|
||||||
|
// This is used internally to avoid releasing and reacquiring the lock during SaveEvent,
|
||||||
|
// which would create a TOCTOU race condition.
|
||||||
|
func (jes *JetStreamEventStore) getLatestVersionLocked(actorID string) (int64, error) {
|
||||||
|
// Create subject filter for this actor
|
||||||
|
subject := fmt.Sprintf("%s.events.%s.%s",
|
||||||
|
jes.streamName,
|
||||||
|
sanitizeSubject(extractActorType(actorID)),
|
||||||
|
sanitizeSubject(actorID))
|
||||||
|
|
||||||
|
// Create consumer to read only the last message
|
||||||
|
consumer, err := jes.js.PullSubscribe(subject, "", nats.DeliverLast())
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to create consumer: %w", err)
|
||||||
|
}
|
||||||
|
defer consumer.Unsubscribe()
|
||||||
|
|
||||||
|
// Fetch only the last message
|
||||||
|
msgs, err := consumer.Fetch(1, nats.MaxWait(time.Second))
|
||||||
|
if err != nil {
|
||||||
|
if err == nats.ErrTimeout {
|
||||||
|
// No messages for this actor, return 0
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return 0, fmt.Errorf("failed to fetch last message: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(msgs) == 0 {
|
||||||
|
// No events for this actor
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse the last message to get the version
|
||||||
|
var event aether.Event
|
||||||
|
if err := json.Unmarshal(msgs[0].Data, &event); err != nil {
|
||||||
|
return 0, fmt.Errorf("failed to unmarshal last event: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msgs[0].Ack()
|
||||||
|
return event.Version, nil
|
||||||
|
}
|
||||||
|
|
||||||
// GetLatestSnapshot gets the most recent snapshot for an actor.
|
// GetLatestSnapshot gets the most recent snapshot for an actor.
|
||||||
// Returns nil if no snapshot exists for the actor (consistent with GetLatestVersion).
|
// Returns an error if no snapshot exists for the actor (unlike GetLatestVersion which returns 0).
|
||||||
|
// This is intentional: a missing snapshot is different from a missing event stream.
|
||||||
|
// If an actor has no events, that's a normal state (use version 0).
|
||||||
|
// If an actor has no snapshot, that could indicate an error or it could be normal
|
||||||
|
// depending on the use case, so we let the caller decide how to handle it.
|
||||||
func (jes *JetStreamEventStore) GetLatestSnapshot(actorID string) (*aether.ActorSnapshot, error) {
|
func (jes *JetStreamEventStore) GetLatestSnapshot(actorID string) (*aether.ActorSnapshot, error) {
|
||||||
// Create subject for snapshots
|
// Create subject for snapshots
|
||||||
subject := fmt.Sprintf("%s.snapshots.%s.%s",
|
subject := fmt.Sprintf("%s.snapshots.%s.%s",
|
||||||
@@ -356,15 +396,15 @@ func (jes *JetStreamEventStore) GetLatestSnapshot(actorID string) (*aether.Actor
|
|||||||
msgs, err := consumer.Fetch(1, nats.MaxWait(time.Second))
|
msgs, err := consumer.Fetch(1, nats.MaxWait(time.Second))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == nats.ErrTimeout {
|
if err == nats.ErrTimeout {
|
||||||
// No snapshot found - return nil (consistent with GetLatestVersion returning 0)
|
// No snapshot found - return error to distinguish from successful nil result
|
||||||
return nil, nil
|
return nil, fmt.Errorf("no snapshot found for actor %s", actorID)
|
||||||
}
|
}
|
||||||
return nil, fmt.Errorf("failed to fetch snapshot: %w", err)
|
return nil, fmt.Errorf("failed to fetch snapshot: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(msgs) == 0 {
|
if len(msgs) == 0 {
|
||||||
// No snapshot exists for this actor
|
// No snapshot exists for this actor
|
||||||
return nil, nil
|
return nil, fmt.Errorf("no snapshot found for actor %s", actorID)
|
||||||
}
|
}
|
||||||
|
|
||||||
var snapshot aether.ActorSnapshot
|
var snapshot aether.ActorSnapshot
|
||||||
|
|||||||
Reference in New Issue
Block a user