diff options
Diffstat (limited to 'source/core/slang-random-generator.cpp')
| -rw-r--r-- | source/core/slang-random-generator.cpp | 78 |
1 files changed, 78 insertions, 0 deletions
diff --git a/source/core/slang-random-generator.cpp b/source/core/slang-random-generator.cpp index ce43067aa..ec06336f1 100644 --- a/source/core/slang-random-generator.cpp +++ b/source/core/slang-random-generator.cpp @@ -71,6 +71,57 @@ int64_t RandomGenerator::nextInt64InRange(int64_t min, int64_t max) return (nextPositiveInt64() % diff) + min; } +static uint8_t* _nextData(RandomGenerator* rand, uint8_t* out, size_t size) +{ + if (size) + { + SLANG_ASSERT(size <= 4); + uint32_t v = uint32_t(rand->nextInt32()); + uint8_t* dst = (uint8_t*)out; + for (size_t i = 0; i < size; ++i) + { + dst[i] = uint8_t(v); + v >>= 8; + } + } + return out + size; +} + +void RandomGenerator::nextData(void* out, size_t size) +{ + uint8_t* dst = (uint8_t*)out; + uint8_t*const end = dst + size; + + // For short runs just output + if (size <= 4) + { + _nextData(this, dst, size); + return; + } + + { + const size_t preAlign = size_t(((size_t(dst) + 3) & ~size_t(3)) - size_t(dst)); + dst = _nextData(this, dst, preAlign); + } + + // Check invariants + SLANG_ASSERT((size_t(dst) & 3) == 0 && end >= dst); + + { + const size_t middleCount = size_t(end - dst) >> 2; + if (middleCount) + { + nextInt32s((int32_t*)dst, middleCount); + dst += middleCount * sizeof(int32_t); + } + } + + // Check invariants + SLANG_ASSERT((size_t(dst) & 3) == 0 && end >= dst); + + _nextData(this, dst, size_t(end - dst)); +} + /* static */RandomGenerator* RandomGenerator::create(int32_t seed) { return new DefaultRandomGenerator(seed); @@ -155,7 +206,34 @@ int32_t Mt19937RandomGenerator::nextInt32() return int32_t(y); } +void Mt19937RandomGenerator::nextInt32s(int32_t* dst, size_t count) +{ + while (count) + { + if (m_index >= kNumEntries) + { + _generate(); + } + + const size_t remaining = kNumEntries - m_index; + const size_t run = (count < remaining) ? count : remaining; + + const uint32_t* src = m_mt + m_index; + for (size_t i = 0; i < run; i++) + { + uint32_t y = src[i]; + y = y ^ (y >> 11); + y = y ^ ((y << 7) & uint32_t(0x9d2c5680u)); + y = y ^ ((y << 15) & uint32_t(0xefc6000u)); + y = y ^ (y >> 18); + dst[i] = int32_t(y); + } + m_index += int(run); + dst += run; + count -= run; + } +} } // namespace Slang |
