Skip to content

Commit 4e7d892

Browse files
jamienicol and teoxoy authored
[naga msl-out hlsl-out] Improve workaround for infinite loops causing undefined behaviour (#6929)
Co-authored-by: Teodor Tanasoaia <[email protected]>
1 parent ad194a8 commit 4e7d892

20 files changed

+223
-95
lines changed

naga/src/back/hlsl/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,9 @@ pub struct Options {
287287
pub zero_initialize_workgroup_memory: bool,
288288
/// Should we restrict indexing of vectors, matrices and arrays?
289289
pub restrict_indexing: bool,
290+
/// If set, loops will have code injected into them, forcing the compiler
291+
/// to think the number of iterations is bounded.
292+
pub force_loop_bounding: bool,
290293
}
291294

292295
impl Default for Options {
@@ -302,6 +305,7 @@ impl Default for Options {
302305
dynamic_storage_buffer_offsets_targets: std::collections::BTreeMap::new(),
303306
zero_initialize_workgroup_memory: true,
304307
restrict_indexing: true,
308+
force_loop_bounding: true,
305309
}
306310
}
307311
}

naga/src/back/hlsl/writer.rs

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,33 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
143143
self.need_bake_expressions.clear();
144144
}
145145

146+
/// Generates statements to be inserted immediately before and at the very
147+
/// start of the body of each loop, to defeat infinite loop reasoning.
148+
/// The 0th item of the returned tuple should be inserted immediately prior
149+
/// to the loop and the 1st item should be inserted at the very start of
150+
/// the loop body.
151+
///
152+
/// See [`back::msl::Writer::gen_force_bounded_loop_statements`] for details.
153+
fn gen_force_bounded_loop_statements(
154+
&mut self,
155+
level: back::Level,
156+
) -> Option<(String, String)> {
157+
if !self.options.force_loop_bounding {
158+
return None;
159+
}
160+
161+
let loop_bound_name = self.namer.call("loop_bound");
162+
let decl = format!("{level}uint2 {loop_bound_name} = uint2(0u, 0u);");
163+
let level = level.next();
164+
let max = u32::MAX;
165+
let break_and_inc = format!(
166+
"{level}if (all({loop_bound_name} == uint2({max}u, {max}u))) {{ break; }}
167+
{level}{loop_bound_name} += uint2({loop_bound_name}.y == {max}u, 1u);"
168+
);
169+
170+
Some((decl, break_and_inc))
171+
}
172+
146173
/// Helper method used to find which expressions of a given function require baking
147174
///
148175
/// # Notes
@@ -2162,12 +2189,24 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
21622189
ref continuing,
21632190
break_if,
21642191
} => {
2192+
let force_loop_bound_statements = self.gen_force_bounded_loop_statements(level);
2193+
let gate_name = (!continuing.is_empty() || break_if.is_some())
2194+
.then(|| self.namer.call("loop_init"));
2195+
2196+
if let Some((ref decl, _)) = force_loop_bound_statements {
2197+
writeln!(self.out, "{decl}")?;
2198+
}
2199+
if let Some(ref gate_name) = gate_name {
2200+
writeln!(self.out, "{level}bool {gate_name} = true;")?;
2201+
}
2202+
21652203
self.continue_ctx.enter_loop();
2204+
writeln!(self.out, "{level}while(true) {{")?;
2205+
if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
2206+
writeln!(self.out, "{break_and_inc}")?;
2207+
}
21662208
let l2 = level.next();
2167-
if !continuing.is_empty() || break_if.is_some() {
2168-
let gate_name = self.namer.call("loop_init");
2169-
writeln!(self.out, "{level}bool {gate_name} = true;")?;
2170-
writeln!(self.out, "{level}while(true) {{")?;
2209+
if let Some(gate_name) = gate_name {
21712210
writeln!(self.out, "{l2}if (!{gate_name}) {{")?;
21722211
let l3 = l2.next();
21732212
for sta in continuing.iter() {
@@ -2182,13 +2221,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
21822221
}
21832222
writeln!(self.out, "{l2}}}")?;
21842223
writeln!(self.out, "{l2}{gate_name} = false;")?;
2185-
} else {
2186-
writeln!(self.out, "{level}while(true) {{")?;
21872224
}
21882225

21892226
for sta in body.iter() {
21902227
self.write_stmt(module, sta, func_ctx, l2)?;
21912228
}
2229+
21922230
writeln!(self.out, "{level}}}")?;
21932231
self.continue_ctx.exit_loop();
21942232
}

naga/src/back/msl/writer.rs

Lines changed: 59 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -383,11 +383,6 @@ pub struct Writer<W> {
383383
/// Set of (struct type, struct field index) denoting which fields require
384384
/// padding inserted **before** them (i.e. between fields at index - 1 and index)
385385
struct_member_pads: FastHashSet<(Handle<crate::Type>, u32)>,
386-
387-
/// Name of the force-bounded-loop macro.
388-
///
389-
/// See `emit_force_bounded_loop_macro` for details.
390-
force_bounded_loop_macro_name: String,
391386
}
392387

393388
impl crate::Scalar {
@@ -601,7 +596,7 @@ struct ExpressionContext<'a> {
601596
/// accesses. These may need to be cached in temporary variables. See
602597
/// `index::find_checked_indexes` for details.
603598
guarded_indices: HandleSet<crate::Expression>,
604-
/// See [`Writer::emit_force_bounded_loop_macro`] for details.
599+
/// See [`Writer::gen_force_bounded_loop_statements`] for details.
605600
force_loop_bounding: bool,
606601
}
607602

@@ -685,7 +680,6 @@ impl<W: Write> Writer<W> {
685680
#[cfg(test)]
686681
put_block_stack_pointers: Default::default(),
687682
struct_member_pads: FastHashSet::default(),
688-
force_bounded_loop_macro_name: String::default(),
689683
}
690684
}
691685

@@ -696,17 +690,11 @@ impl<W: Write> Writer<W> {
696690
self.out
697691
}
698692

699-
/// Define a macro to invoke at the bottom of each loop body, to
700-
/// defeat MSL infinite loop reasoning.
701-
///
702-
/// If we haven't done so already, emit the definition of a preprocessor
703-
/// macro to be invoked at the end of each loop body in the generated MSL,
704-
/// to ensure that the MSL compiler's optimizations do not remove bounds
705-
/// checks.
706-
///
707-
/// Only the first call to this function for a given module actually causes
708-
/// the macro definition to be written. Subsequent loops can simply use the
709-
/// prior macro definition, since macros aren't block-scoped.
693+
/// Generates statements to be inserted immediately before and at the very
694+
/// start of the body of each loop, to defeat MSL infinite loop reasoning.
695+
/// The 0th item of the returned tuple should be inserted immediately prior
696+
/// to the loop and the 1st item should be inserted at the very start of
697+
/// the loop body.
710698
///
711699
/// # What is this trying to solve?
712700
///
@@ -774,7 +762,8 @@ impl<W: Write> Writer<W> {
774762
/// but which in fact generates no instructions. Unfortunately, inline
775763
/// assembly is not handled correctly by some Metal device drivers.
776764
///
777-
/// Instead, we add the following code to the bottom of every loop:
765+
/// A previously used approach was to add the following code to the bottom
766+
/// of every loop:
778767
///
779768
/// ```ignore
780769
/// if (volatile bool unpredictable = false; unpredictable)
@@ -785,37 +774,47 @@ impl<W: Write> Writer<W> {
785774
/// the `volatile` qualifier prevents the compiler from assuming this. Thus,
786775
/// it must assume that the `break` might be reached, and hence that the
787776
/// loop is not unbounded. This prevents the range analysis impact described
788-
/// above.
777+
/// above. Unfortunately this prevented the compiler from making important,
778+
/// and safe, optimizations such as loop unrolling and was observed to
779+
/// significantly hurt performance.
789780
///
790-
/// Unfortunately, what makes this a kludge, not a hack, is that this
791-
/// solution leaves the GPU executing a pointless conditional branch, at
792-
/// runtime, in every iteration of the loop. There's no part of the system
793-
/// that has a global enough view to be sure that `unpredictable` is true,
794-
/// and remove it from the code. Adding the branch also affects
795-
/// optimization: for example, it's impossible to unroll this loop. This
796-
/// transformation has been observed to significantly hurt performance.
781+
/// Our current approach declares a counter before every loop and
782+
/// increments it every iteration, breaking after 2^64 iterations:
783+
///
784+
/// ```ignore
785+
/// uint2 loop_bound = uint2(0);
786+
/// while (true) {
787+
/// if (metal::all(loop_bound == uint2(4294967295))) { break; }
788+
/// loop_bound += uint2(loop_bound.y == 4294967295, 1);
789+
/// }
790+
/// ```
797791
///
798-
/// To make our output a bit more legible, we pull the condition out into a
799-
/// preprocessor macro defined at the top of the module.
792+
/// This convinces the compiler that the loop is finite and therefore may
793+
/// execute, whilst at the same time allowing optimizations such as loop
794+
/// unrolling. Furthermore the 64-bit counter is large enough it seems
795+
/// implausible that it would affect the execution of any shader.
800796
///
801797
/// This approach is also used by Chromium WebGPU's Dawn shader compiler:
802-
/// <https://dawn.googlesource.com/dawn/+/a37557db581c2b60fb1cd2c01abdb232927dd961/src/tint/lang/msl/writer/printer/printer.cc#222>
803-
fn emit_force_bounded_loop_macro(&mut self) -> BackendResult {
804-
if !self.force_bounded_loop_macro_name.is_empty() {
805-
return Ok(());
798+
/// <https://dawn.googlesource.com/dawn/+/d9e2d1f718678ebee0728b999830576c410cce0a/src/tint/lang/core/ir/transform/prevent_infinite_loops.cc>
799+
fn gen_force_bounded_loop_statements(
800+
&mut self,
801+
level: back::Level,
802+
context: &StatementContext,
803+
) -> Option<(String, String)> {
804+
if !context.expression.force_loop_bounding {
805+
return None;
806806
}
807807

808-
self.force_bounded_loop_macro_name = self.namer.call("LOOP_IS_BOUNDED");
809-
let loop_bounded_volatile_name = self.namer.call("unpredictable_break_from_loop");
810-
writeln!(
811-
self.out,
812-
"#define {} {{ volatile bool {} = false; if ({}) break; }}",
813-
self.force_bounded_loop_macro_name,
814-
loop_bounded_volatile_name,
815-
loop_bounded_volatile_name,
816-
)?;
808+
let loop_bound_name = self.namer.call("loop_bound");
809+
let decl = format!("{level}uint2 {loop_bound_name} = uint2(0u);");
810+
let level = level.next();
811+
let max = u32::MAX;
812+
let break_and_inc = format!(
813+
"{level}if ({NAMESPACE}::all({loop_bound_name} == uint2({max}u))) {{ break; }}
814+
{level}{loop_bound_name} += uint2({loop_bound_name}.y == {max}u, 1u);"
815+
);
817816

818-
Ok(())
817+
Some((decl, break_and_inc))
819818
}
820819

821820
fn put_call_parameters(
@@ -3201,10 +3200,23 @@ impl<W: Write> Writer<W> {
32013200
ref continuing,
32023201
break_if,
32033202
} => {
3204-
if !continuing.is_empty() || break_if.is_some() {
3205-
let gate_name = self.namer.call("loop_init");
3203+
let force_loop_bound_statements =
3204+
self.gen_force_bounded_loop_statements(level, context);
3205+
let gate_name = (!continuing.is_empty() || break_if.is_some())
3206+
.then(|| self.namer.call("loop_init"));
3207+
3208+
if let Some((ref decl, _)) = force_loop_bound_statements {
3209+
writeln!(self.out, "{decl}")?;
3210+
}
3211+
if let Some(ref gate_name) = gate_name {
32063212
writeln!(self.out, "{level}bool {gate_name} = true;")?;
3207-
writeln!(self.out, "{level}while(true) {{",)?;
3213+
}
3214+
3215+
writeln!(self.out, "{level}while(true) {{",)?;
3216+
if let Some((_, ref break_and_inc)) = force_loop_bound_statements {
3217+
writeln!(self.out, "{break_and_inc}")?;
3218+
}
3219+
if let Some(ref gate_name) = gate_name {
32083220
let lif = level.next();
32093221
let lcontinuing = lif.next();
32103222
writeln!(self.out, "{lif}if (!{gate_name}) {{")?;
@@ -3218,19 +3230,9 @@ impl<W: Write> Writer<W> {
32183230
}
32193231
writeln!(self.out, "{lif}}}")?;
32203232
writeln!(self.out, "{lif}{gate_name} = false;")?;
3221-
} else {
3222-
writeln!(self.out, "{level}while(true) {{",)?;
32233233
}
32243234
self.put_block(level.next(), body, context)?;
3225-
if context.expression.force_loop_bounding {
3226-
self.emit_force_bounded_loop_macro()?;
3227-
writeln!(
3228-
self.out,
3229-
"{}{}",
3230-
level.next(),
3231-
self.force_bounded_loop_macro_name
3232-
)?;
3233-
}
3235+
32343236
writeln!(self.out, "{level}}}")?;
32353237
}
32363238
crate::Statement::Break => {
@@ -3724,7 +3726,6 @@ impl<W: Write> Writer<W> {
37243726
&[CLAMPED_LOD_LOAD_PREFIX],
37253727
&mut self.names,
37263728
);
3727-
self.force_bounded_loop_macro_name.clear();
37283729
self.struct_member_pads.clear();
37293730

37303731
writeln!(

naga/tests/out/hlsl/boids.hlsl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,11 @@ void main(uint3 global_invocation_id : SV_DispatchThreadID)
4141
vPos = _e8;
4242
float2 _e14 = asfloat(particlesSrc.Load2(8+index*16+0));
4343
vVel = _e14;
44+
uint2 loop_bound = uint2(0u, 0u);
4445
bool loop_init = true;
4546
while(true) {
47+
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
48+
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
4649
if (!loop_init) {
4750
uint _e91 = i;
4851
i = (_e91 + 1u);

naga/tests/out/hlsl/break-if.hlsl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
void breakIfEmpty()
22
{
3+
uint2 loop_bound = uint2(0u, 0u);
34
bool loop_init = true;
45
while(true) {
6+
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
7+
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
58
if (!loop_init) {
69
if (true) {
710
break;
@@ -17,8 +20,11 @@ void breakIfEmptyBody(bool a)
1720
bool b = (bool)0;
1821
bool c = (bool)0;
1922

23+
uint2 loop_bound_1 = uint2(0u, 0u);
2024
bool loop_init_1 = true;
2125
while(true) {
26+
if (all(loop_bound_1 == uint2(4294967295u, 4294967295u))) { break; }
27+
loop_bound_1 += uint2(loop_bound_1.y == 4294967295u, 1u);
2228
if (!loop_init_1) {
2329
b = a;
2430
bool _e2 = b;
@@ -38,8 +44,11 @@ void breakIf(bool a_1)
3844
bool d = (bool)0;
3945
bool e = (bool)0;
4046

47+
uint2 loop_bound_2 = uint2(0u, 0u);
4148
bool loop_init_2 = true;
4249
while(true) {
50+
if (all(loop_bound_2 == uint2(4294967295u, 4294967295u))) { break; }
51+
loop_bound_2 += uint2(loop_bound_2.y == 4294967295u, 1u);
4352
if (!loop_init_2) {
4453
bool _e5 = e;
4554
if ((a_1 == _e5)) {
@@ -58,8 +67,11 @@ void breakIfSeparateVariable()
5867
{
5968
uint counter = 0u;
6069

70+
uint2 loop_bound_3 = uint2(0u, 0u);
6171
bool loop_init_3 = true;
6272
while(true) {
73+
if (all(loop_bound_3 == uint2(4294967295u, 4294967295u))) { break; }
74+
loop_bound_3 += uint2(loop_bound_3.y == 4294967295u, 1u);
6375
if (!loop_init_3) {
6476
uint _e5 = counter;
6577
if ((_e5 == 5u)) {

naga/tests/out/hlsl/collatz.hlsl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@ uint collatz_iterations(uint n_base)
66
uint i = 0u;
77

88
n = n_base;
9+
uint2 loop_bound = uint2(0u, 0u);
910
while(true) {
11+
if (all(loop_bound == uint2(4294967295u, 4294967295u))) { break; }
12+
loop_bound += uint2(loop_bound.y == 4294967295u, 1u);
1013
uint _e4 = n;
1114
if ((_e4 > 1u)) {
1215
} else {

0 commit comments

Comments
 (0)