summaryrefslogtreecommitdiffstats
path: root/source/core/slang-random-generator.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/core/slang-random-generator.cpp')
-rw-r--r--source/core/slang-random-generator.cpp78
1 files changed, 78 insertions, 0 deletions
diff --git a/source/core/slang-random-generator.cpp b/source/core/slang-random-generator.cpp
index ce43067aa..ec06336f1 100644
--- a/source/core/slang-random-generator.cpp
+++ b/source/core/slang-random-generator.cpp
@@ -71,6 +71,57 @@ int64_t RandomGenerator::nextInt64InRange(int64_t min, int64_t max)
return (nextPositiveInt64() % diff) + min;
}
+static uint8_t* _nextData(RandomGenerator* rand, uint8_t* out, size_t size)
+{
+ if (size)
+ {
+ SLANG_ASSERT(size <= 4);
+ uint32_t v = uint32_t(rand->nextInt32());
+ uint8_t* dst = (uint8_t*)out;
+ for (size_t i = 0; i < size; ++i)
+ {
+ dst[i] = uint8_t(v);
+ v >>= 8;
+ }
+ }
+ return out + size;
+}
+
+void RandomGenerator::nextData(void* out, size_t size)
+{
+ uint8_t* dst = (uint8_t*)out;
+ uint8_t*const end = dst + size;
+
+ // For short runs just output
+ if (size <= 4)
+ {
+ _nextData(this, dst, size);
+ return;
+ }
+
+ {
+ const size_t preAlign = size_t(((size_t(dst) + 3) & ~size_t(3)) - size_t(dst));
+ dst = _nextData(this, dst, preAlign);
+ }
+
+ // Check invariants
+ SLANG_ASSERT((size_t(dst) & 3) == 0 && end >= dst);
+
+ {
+ const size_t middleCount = size_t(end - dst) >> 2;
+ if (middleCount)
+ {
+ nextInt32s((int32_t*)dst, middleCount);
+ dst += middleCount * sizeof(int32_t);
+ }
+ }
+
+ // Check invariants
+ SLANG_ASSERT((size_t(dst) & 3) == 0 && end >= dst);
+
+ _nextData(this, dst, size_t(end - dst));
+}
+
/* static */RandomGenerator* RandomGenerator::create(int32_t seed)
{
return new DefaultRandomGenerator(seed);
@@ -155,7 +206,34 @@ int32_t Mt19937RandomGenerator::nextInt32()
return int32_t(y);
}
+void Mt19937RandomGenerator::nextInt32s(int32_t* dst, size_t count)
+{
+ while (count)
+ {
+ if (m_index >= kNumEntries)
+ {
+ _generate();
+ }
+
+ const size_t remaining = kNumEntries - m_index;
+ const size_t run = (count < remaining) ? count : remaining;
+
+ const uint32_t* src = m_mt + m_index;
+ for (size_t i = 0; i < run; i++)
+ {
+ uint32_t y = src[i];
+ y = y ^ (y >> 11);
+ y = y ^ ((y << 7) & uint32_t(0x9d2c5680u));
+ y = y ^ ((y << 15) & uint32_t(0xefc6000u));
+ y = y ^ (y >> 18);
+ dst[i] = int32_t(y);
+ }
+ m_index += int(run);
+ dst += run;
+ count -= run;
+ }
+}
} // namespace Slang