Skip to content

Commit

Permalink
Add fiber safety to crystal/once (crystal-lang#15370)
Browse files Browse the repository at this point in the history
  • Loading branch information
ysbaddaden committed Jan 30, 2025
1 parent 2978cd1 commit f6adb40
Show file tree
Hide file tree
Showing 7 changed files with 120 additions and 113 deletions.
6 changes: 3 additions & 3 deletions src/compiler/crystal/codegen/class_var.cr
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class Crystal::CodeGenVisitor
initialized_flag_name = class_var_global_initialized_name(class_var)
initialized_flag = @main_mod.globals[initialized_flag_name]?
unless initialized_flag
initialized_flag = @main_mod.globals.add(@main_llvm_context.int8, initialized_flag_name)
initialized_flag.initializer = @main_llvm_context.int8.const_int(0)
initialized_flag = @main_mod.globals.add(@main_llvm_context.int1, initialized_flag_name)
initialized_flag.initializer = @main_llvm_context.int1.const_int(0)
initialized_flag.linkage = LLVM::Linkage::Internal if @single_module
initialized_flag.thread_local = true if class_var.thread_local?
end
Expand Down Expand Up @@ -61,7 +61,7 @@ class Crystal::CodeGenVisitor
initialized_flag_name = class_var_global_initialized_name(class_var)
initialized_flag = @llvm_mod.globals[initialized_flag_name]?
unless initialized_flag
initialized_flag = @llvm_mod.globals.add(llvm_context.int8, initialized_flag_name)
initialized_flag = @llvm_mod.globals.add(llvm_context.int1, initialized_flag_name)
initialized_flag.thread_local = true if class_var.thread_local?
end
end
Expand Down
4 changes: 2 additions & 2 deletions src/compiler/crystal/codegen/const.cr
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ class Crystal::CodeGenVisitor
initialized_flag_name = const.initialized_llvm_name
initialized_flag = @main_mod.globals[initialized_flag_name]?
unless initialized_flag
initialized_flag = @main_mod.globals.add(@main_llvm_context.int8, initialized_flag_name)
initialized_flag.initializer = @main_llvm_context.int8.const_int(0)
initialized_flag = @main_mod.globals.add(@main_llvm_context.int1, initialized_flag_name)
initialized_flag.initializer = @main_llvm_context.int1.const_int(0)
initialized_flag.linkage = LLVM::Linkage::Internal if @single_module
end
initialized_flag
Expand Down
3 changes: 0 additions & 3 deletions src/compiler/crystal/codegen/once.cr
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,6 @@ class Crystal::CodeGenVisitor
end

state = load(once_state_type, once_state_global)
{% if LibLLVM::IS_LT_150 %}
flag = bit_cast(flag, @llvm_context.int1.pointer) # cast Int8* to Bool*
{% end %}
args = [state, flag, initializer]
end

Expand Down
2 changes: 1 addition & 1 deletion src/crystal/main.cr
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ module Crystal
# so we explicitly initialize their class vars, then init crystal/once
Thread.init
Fiber.init
Crystal.once_init
Crystal::Once.init
end

# :nodoc:
Expand Down
210 changes: 110 additions & 100 deletions src/crystal/once.cr
Original file line number Diff line number Diff line change
Expand Up @@ -7,138 +7,148 @@
# with older compiler releases. It is executed only once at the beginning of the
# program and, for the legacy implementation, the result is passed on each call
# to `__crystal_once`.
#
# In multithread mode a mutex is used to avoid race conditions between threads.
#
# On Win32, `Crystal::System::FileDescriptor#@@reader_thread` spawns a new
# thread even without the `preview_mt` flag, and the thread can also reference
# Crystal constants, leading to race conditions, so we always enable the mutex.

{% if compare_versions(Crystal::VERSION, "1.16.0-dev") >= 0 %}
# This implementation uses an enum over the initialization flag pointer for
# each value to find infinite loops and raise an error.

module Crystal
# :nodoc:
enum OnceState : Int8
Processing = -1
Uninitialized = 0
Initialized = 1
require "crystal/pointer_linked_list"
require "crystal/spin_lock"

module Crystal
# :nodoc:
module Once
struct Operation
include PointerLinkedList::Node

getter fiber : Fiber
getter flag : Bool*

def initialize(@flag : Bool*, @fiber : Fiber)
@waiting = PointerLinkedList(Fiber::PointerLinkedListNode).new
end

def add_waiter(node) : Nil
@waiting.push(node)
end

def resume_all : Nil
@waiting.each(&.value.enqueue)
end
end

{% if flag?(:preview_mt) || flag?(:win32) %}
@@once_mutex = uninitialized Mutex
{% end %}
@@spin = uninitialized SpinLock
@@operations = uninitialized PointerLinkedList(Operation)

# :nodoc:
def self.once_init : Nil
{% if flag?(:preview_mt) || flag?(:win32) %}
@@once_mutex = Mutex.new(:reentrant)
{% end %}
def self.init : Nil
@@spin = SpinLock.new
@@operations = PointerLinkedList(Operation).new
end

# :nodoc:
# Using @[NoInline] so LLVM optimizes for the hot path (var already
# initialized).
@[NoInline]
def self.once(flag : OnceState*, initializer : Void*) : Nil
{% if flag?(:preview_mt) || flag?(:win32) %}
@@once_mutex.synchronize { once_exec(flag, initializer) }
{% else %}
once_exec(flag, initializer)
{% end %}
protected def self.exec(flag : Bool*, &)
@@spin.lock

if flag.value
@@spin.unlock
elsif operation = processing?(flag)
check_reentrancy(operation)
wait_initializer(operation)
else
run_initializer(flag) { yield }
end

# safety check, and allows to safely call `Intrinsics.unreachable` in
# `__crystal_once`
unless flag.value.initialized?
System.print_error "BUG: failed to initialize constant or class variable\n"
LibC._exit(1)
return if flag.value

System.print_error "BUG: failed to initialize class variable or constant\n"
LibC._exit(1)
end

private def self.processing?(flag)
@@operations.each do |operation|
return operation if operation.value.flag == flag
end
end

private def self.once_exec(flag : OnceState*, initializer : Void*) : Nil
case flag.value
in .initialized?
return
in .uninitialized?
flag.value = :processing
Proc(Nil).new(initializer, Pointer(Void).null).call
flag.value = :initialized
in .processing?
private def self.check_reentrancy(operation)
if operation.value.fiber == Fiber.current
@@spin.unlock
raise "Recursion while initializing class variables and/or constants"
end
end

private def self.wait_initializer(operation)
waiting = Fiber::PointerLinkedListNode.new(Fiber.current)
operation.value.add_waiter(pointerof(waiting))
@@spin.unlock
Fiber.suspend
end

private def self.run_initializer(flag, &)
operation = Operation.new(flag, Fiber.current)
@@operations.push pointerof(operation)
@@spin.unlock

yield

@@spin.lock
flag.value = true
@@operations.delete pointerof(operation)
@@spin.unlock

operation.resume_all
end
end

# :nodoc:
#
# Using `@[AlwaysInline]` allows LLVM to optimize const accesses. Since this
# is a `fun` the function will still appear in the symbol table, though it
# will never be called.
@[AlwaysInline]
fun __crystal_once(flag : Crystal::OnceState*, initializer : Void*) : Nil
return if flag.value.initialized?

Crystal.once(flag, initializer)
# Never inlined to avoid bloating the call site with the slow-path that should
# usually not be taken.
@[NoInline]
def self.once(flag : Bool*, initializer : Void*)
Once.exec(flag, &Proc(Nil).new(initializer, Pointer(Void).null))
end

# tell LLVM that it can optimize away repeated `__crystal_once` calls for
# this global (e.g. repeated access to constant in a single funtion);
# this is truly unreachable otherwise `Crystal.once` would have panicked
Intrinsics.unreachable unless flag.value.initialized?
# :nodoc:
#
# NOTE: should also never be inlined, but that would capture the block, which
# would be a breaking change when we use this method to protect class getter
# and class property macros with lazy initialization (the block may return or
# break).
#
# TODO: consider a compile time flag to enable/disable the capture? returning
# from the block is unexpected behavior: the returned value won't be saved in
# the class variable.
def self.once(flag : Bool*, &)
Once.exec(flag) { yield } unless flag.value
end
{% else %}
# This implementation uses a global array to store the initialization flag
# pointers for each value to find infinite loops and raise an error.

module Crystal
# :nodoc:
class OnceState
@rec = [] of Bool*

@[NoInline]
def once(flag : Bool*, initializer : Void*)
unless flag.value
if @rec.includes?(flag)
raise "Recursion while initializing class variables and/or constants"
end
@rec << flag

Proc(Nil).new(initializer, Pointer(Void).null).call
flag.value = true

@rec.pop
end
end
end

{% if flag?(:preview_mt) || flag?(:win32) %}
@mutex = Mutex.new(:reentrant)

@[NoInline]
def once(flag : Bool*, initializer : Void*)
unless flag.value
@mutex.synchronize do
previous_def
end
end
end
{% end %}
end
{% if compare_versions(Crystal::VERSION, "1.16.0-dev") >= 0 %}
# :nodoc:
#
# We always inline this accessor to optimize for the fast-path (already
# initialized).
@[AlwaysInline]
fun __crystal_once(flag : Bool*, initializer : Void*)
return if flag.value
Crystal.once(flag, initializer)

# :nodoc:
def self.once_init : Nil
end
# tells LLVM to assume that the flag is true, this avoids repeated access to
# the same constant or class variable to check the flag and try to run the
# initializer (only the first access will)
Intrinsics.unreachable unless flag.value
end

{% else %}
# :nodoc:
#
# Unused. Kept for backward compatibility with older compilers.
fun __crystal_once_init : Void*
Crystal::OnceState.new.as(Void*)
Pointer(Void).null
end

# :nodoc:
@[AlwaysInline]
fun __crystal_once(state : Void*, flag : Bool*, initializer : Void*)
return if flag.value
state.as(Crystal::OnceState).once(flag, initializer)
Crystal.once(flag, initializer)
Intrinsics.unreachable unless flag.value
end
{% end %}
6 changes: 3 additions & 3 deletions src/crystal/spin_lock.cr
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@ struct Crystal::SpinLock
private UNLOCKED = 0
private LOCKED = 1

{% if flag?(:preview_mt) %}
{% if flag?(:preview_mt) || flag?(:win32) %}
@m = Atomic(Int32).new(UNLOCKED)
{% end %}

def lock
{% if flag?(:preview_mt) %}
{% if flag?(:preview_mt) || flag?(:win32) %}
while @m.swap(LOCKED, :acquire) == LOCKED
while @m.get(:relaxed) == LOCKED
Intrinsics.pause
Expand All @@ -18,7 +18,7 @@ struct Crystal::SpinLock
end

def unlock
{% if flag?(:preview_mt) %}
{% if flag?(:preview_mt) || flag?(:win32) %}
@m.set(UNLOCKED, :release)
{% end %}
end
Expand Down
2 changes: 1 addition & 1 deletion src/prelude.cr
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
# appear in the API docs.

# This list requires ordered statements
require "crystal/once"
require "lib_c"
require "macros"
require "object"
require "crystal/once"
require "comparable"
require "exception"
require "iterable"
Expand Down

0 comments on commit f6adb40

Please sign in to comment.