From 5f4f454b9662b83facd71dcbc6e8ac105660e938 Mon Sep 17 00:00:00 2001
From: Kaylee Lubick <kjlubick@google.com>
Date: Wed, 20 May 2026 08:23:28 -0400
Subject: [PATCH] Fix for integer wraparound in sksl

I found the ES 2 spec [1] helpful for reference here.

The calculate_count_neq_int is not strictly necessary (I was unable
to find a case that tricked the existing floats with ints), but
I like the refactoring and it mirrors the gt/lt cases nicely.

[1] https://registry.khronos.org/OpenGL/specs/es/2.0/GLSL_ES_Specification_1.00.pdf

Change-Id: I0b9117f347e4b7d5d336de0f14337b9bec510ff2
Bug: https://issues.chromium.org/issues/513337118
Fixed: 513337118
Reviewed-on: https://skia-review.googlesource.com/c/skia/+/1236656
Commit-Queue: Florin Malita <fmalita@google.com>
Reviewed-by: Florin Malita <fmalita@google.com>
Commit-Queue: Kaylee Lubick <kjlubick@google.com>
Auto-Submit: Kaylee Lubick <kjlubick@google.com>
---
 gn/sksl_tests.gni                            |   2 +
 resources/sksl/BUILD.bazel                   |   2 +
 resources/sksl/errors/ForLoopNEQOverflow.rts |  11 ++
 resources/sksl/errors/ForLoopOverflow.rts    |  11 ++
 src/base/SkSafeMath.h                        |   9 ++
 src/sksl/analysis/SkSLGetLoopUnrollInfo.cpp  | 130 ++++++++++++++++---
 tests/sksl/errors/ForLoopNEQOverflow.glsl    |   6 +
 tests/sksl/errors/ForLoopOverflow.glsl       |   6 +
 8 files changed, 157 insertions(+), 20 deletions(-)
 create mode 100644 resources/sksl/errors/ForLoopNEQOverflow.rts
 create mode 100644 resources/sksl/errors/ForLoopOverflow.rts
 create mode 100644 tests/sksl/errors/ForLoopNEQOverflow.glsl
 create mode 100644 tests/sksl/errors/ForLoopOverflow.glsl

diff --git a/gn/sksl_tests.gni b/gn/sksl_tests.gni
index 1cd9317f04..1a1804f842 100644
--- a/gn/sksl_tests.gni
+++ b/gn/sksl_tests.gni
@@ -142,6 +142,8 @@ sksl_error_tests = [
   "errors/FloatRemainder.rts",
   "errors/ForInitStmt.sksl",
   "errors/ForLoopInductionVariableScope.sksl",
+  "errors/ForLoopNEQOverflow.rts",
+  "errors/ForLoopOverflow.rts",
   "errors/ForTypeMismatch.rts",
   "errors/FunctionParamBadType.rts",
   "errors/FunctionParamShadowedByLocal.rts",
diff --git a/src/base/SkSafeMath.h b/src/base/SkSafeMath.h
index d8f9fbc965..590098a5ca 100644
--- a/src/base/SkSafeMath.h
+++ b/src/base/SkSafeMath.h
@@ -41,6 +41,7 @@ public:
      *  be set to false, and it is undefined what this returns.
      */
     int addInt(int a, int b) {
+        static_assert(sizeof(int) == 4, "int is not 4 bytes");
         if (b < 0 && a < std::numeric_limits<int>::min() - b) {
             fOK = false;
             return a;
@@ -51,6 +52,14 @@ public:
         return a + b;
     }
 
+    int subInt(int a, int b) {
+        if (b == std::numeric_limits<int>::min()) {
+            fOK = false;
+            return a;
+        }
+        return addInt(a, -b);
+    }
+
     int mulInt(int x, int y) {
         int64_t result = (int64_t)x * (int64_t)y;
         if (result > std::numeric_limits<int>::max() || result < std::numeric_limits<int>::min()) {
diff --git a/src/sksl/analysis/SkSLGetLoopUnrollInfo.cpp b/src/sksl/analysis/SkSLGetLoopUnrollInfo.cpp
index 70a2170ba3..974db63e36 100644
--- a/src/sksl/analysis/SkSLGetLoopUnrollInfo.cpp
+++ b/src/sksl/analysis/SkSLGetLoopUnrollInfo.cpp
@@ -7,6 +7,7 @@
 
 #include "include/core/SkTypes.h"
 #include "include/private/base/SkFloatingPoint.h"
+#include "src/base/SkSafeMath.h"
 #include "src/sksl/SkSLAnalysis.h"
 #include "src/sksl/SkSLConstantFolder.h"
 #include "src/sksl/SkSLErrorReporter.h"
@@ -35,25 +36,106 @@ class Context;
 // Loops that run for 100000+ iterations will exceed our program size limit.
 static constexpr int kLoopTerminationLimit = 100000;
 
-static int calculate_count(double start, double end, double delta, bool forwards, bool inclusive) {
-    if ((forwards && start > end) || (!forwards && start < end)) {
+enum class Direction {
+    kBackwards,
+    kForwards,
+};
+
+enum class Inclusive : bool {
+    kNo = false,
+    kYes = true,
+};
+
+enum class LoopType {
+    kFloat,
+    kInt,
+};
+
+static int calculate_count_float(double start, double end, double delta,
+                                 Inclusive inclusive) {
+    double iterations = sk_ieee_double_divide(end - start, delta);
+    double count = std::ceil(iterations);
+    if (inclusive == Inclusive::kYes && (count == iterations)) {
+        count += 1.0;
+    }
+    if (count > kLoopTerminationLimit || !std::isfinite(count)) {
+        // The loop runs for more iterations than we can safely unroll.
+        return kLoopTerminationLimit;
+    }
+    return sk_double_saturate2int(count);
+}
+
+static int calculate_count_int(int32_t start, int32_t end, int32_t delta,
+                               Inclusive inclusive) {
+    if (delta == 0) {
+        return kLoopTerminationLimit;
+    }
+    SkSafeMath math;
+    int roundUp = delta > 0 ? math.subInt(delta, 1) : math.addInt(delta, 1);
+    int width = math.subInt(end, start);
+    int iterations = math.addInt(width, roundUp) / delta;
+    if (inclusive == Inclusive::kYes && width % delta == 0) {
+        iterations = math.addInt(iterations, 1);
+    }
+    // Check that we won't overflow while looping
+    math.addInt(start, math.mulInt(delta, iterations));
+    if (!math || iterations < 0 || iterations > kLoopTerminationLimit) {
+        return kLoopTerminationLimit;
+    }
+    return iterations;
+}
+
+static int calculate_count(double start, double end, double delta, Direction dir,
+                           Inclusive inclusive, LoopType loop) {
+    if ((dir == Direction::kForwards && start > end) ||
+        (dir == Direction::kBackwards && start < end)) {
         // The loop starts in a completed state (the start has already advanced past the end).
         return 0;
     }
-    if ((delta == 0.0) || forwards != (delta > 0.0)) {
+    if ((delta == 0.0) ||
+        (delta > 0.0 && dir == Direction::kBackwards) ||
+        (delta < 0.0 && dir == Direction::kForwards)) {
         // The loop does not progress toward a completed state, and will never terminate.
         return kLoopTerminationLimit;
     }
+    if (loop == LoopType::kInt) {
+        return calculate_count_int((int32_t)start, (int32_t)end, (int32_t)delta, inclusive);
+    }
+    return calculate_count_float(start, end, delta, inclusive);
+}
+
+static int calculate_count_neq_int(int32_t start, int32_t end, int32_t delta) {
+    if (delta == 0) {
+        return kLoopTerminationLimit;
+    }
+    SkSafeMath math;
+    int iterations = math.subInt(end, start) / delta;
+    // Check that we won't overflow while looping and that we actually hit end.
+    int lastValue = math.addInt(start, math.mulInt(delta, iterations));
+    if (!math || lastValue != end || iterations < 0 || iterations > kLoopTerminationLimit) {
+        return kLoopTerminationLimit;
+    }
+    return iterations;
+}
+
+static int calculate_count_neq_float(double start, double end, double delta) {
+    if (delta == 0.0) {
+        return kLoopTerminationLimit;
+    }
     double iterations = sk_ieee_double_divide(end - start, delta);
     double count = std::ceil(iterations);
-    if (inclusive && (count == iterations)) {
-        count += 1.0;
-    }
-    if (count > kLoopTerminationLimit || !std::isfinite(count)) {
-        // The loop runs for more iterations than we can safely unroll.
+    if (count < 0 || count != iterations || !std::isfinite(iterations)) {
+        // The loop doesn't reach the exact endpoint and so will never terminate.
         return kLoopTerminationLimit;
     }
-    return (int)count;
+    return sk_double_saturate2int(count);
+}
+
+static int calculate_count_neq(double start, double end, double delta, LoopType loop) {
+    if (loop == LoopType::kInt) {
+        return calculate_count_neq_int((int32_t)start, (int32_t)end, (int32_t)delta);
+    }
+    return calculate_count_neq_float(start, end, delta);
 }
 
 std::unique_ptr<LoopUnrollInfo> Analysis::GetLoopUnrollInfo(const Context& context,
@@ -226,35 +308,43 @@ std::unique_ptr<LoopUnrollInfo> Analysis::GetLoopUnrollInfo(const Context& conte
     // Finally, compute the iteration count, based on the bounds, and the termination operator.
     loopInfo->fCount = 0;
 
+    // Strict ES2 requires loop induction variables to be either 'int' or 'float'. For 'int'
+    // variables, we simulate the loop using 32-bit signed math to correctly detect the integer
+    // wraparound behavior that would occur at runtime on the GPU. (For 'float' variables,
+    // the existing double-precision calculation is sufficient.)
+    LoopType loop;
+    if (initDecl.baseType().isSigned()) {
+        SkASSERT(initDecl.baseType().bitWidth() == 32);
+        loop = LoopType::kInt;
+    } else {
+        SkASSERT(initDecl.baseType().isFloat());
+        loop = LoopType::kFloat;
+    }
+
     switch (cond->getOperator().kind()) {
         case Operator::Kind::LT:
             loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta,
-                                              /*forwards=*/true, /*inclusive=*/false);
+                                               Direction::kForwards, Inclusive::kNo, loop);
             break;
 
         case Operator::Kind::GT:
             loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta,
-                                              /*forwards=*/false, /*inclusive=*/false);
+                                               Direction::kBackwards, Inclusive::kNo, loop);
             break;
 
         case Operator::Kind::LTEQ:
             loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta,
-                                              /*forwards=*/true, /*inclusive=*/true);
+                                               Direction::kForwards, Inclusive::kYes, loop);
             break;
 
         case Operator::Kind::GTEQ:
             loopInfo->fCount = calculate_count(loopInfo->fStart, loopEnd, loopInfo->fDelta,
-                                              /*forwards=*/false, /*inclusive=*/true);
+                                               Direction::kBackwards, Inclusive::kYes, loop);
             break;
 
         case Operator::Kind::NEQ: {
-            float iterations = sk_ieee_double_divide(loopEnd - loopInfo->fStart, loopInfo->fDelta);
-            loopInfo->fCount = std::ceil(iterations);
-            if (loopInfo->fCount < 0 || loopInfo->fCount != iterations ||
-                !std::isfinite(iterations)) {
-                // The loop doesn't reach the exact endpoint and so will never terminate.
-                loopInfo->fCount = kLoopTerminationLimit;
-            }
+            loopInfo->fCount = calculate_count_neq(loopInfo->fStart, loopEnd, loopInfo->fDelta,
+                                                   loop);
             if (loopInfo->fIndex->type().componentType().isFloat()) {
                 // Rewrite `x != n` tests as `x < n` or `x > n` depending on the loop direction.
                 // Less-than and greater-than tests avoid infinite loops caused by rounding error.
-- 
2.43.0

