summaryrefslogtreecommitdiffstats
path: root/source/core
diff options
context:
space:
mode:
Diffstat (limited to 'source/core')
-rw-r--r--source/core/slang-linked-list.h2
-rw-r--r--source/core/slang-uint-set.cpp62
-rw-r--r--source/core/slang-uint-set.h209
3 files changed, 230 insertions, 43 deletions
diff --git a/source/core/slang-linked-list.h b/source/core/slang-linked-list.h
index 93b5e435c..840ef8cd6 100644
--- a/source/core/slang-linked-list.h
+++ b/source/core/slang-linked-list.h
@@ -323,7 +323,7 @@ public:
}
return rs;
}
- int getCount() { return count; }
+ int getCount() const { return count; }
};
} // namespace Slang
#endif
diff --git a/source/core/slang-uint-set.cpp b/source/core/slang-uint-set.cpp
index e973cbc3a..b6871c192 100644
--- a/source/core/slang-uint-set.cpp
+++ b/source/core/slang-uint-set.cpp
@@ -3,18 +3,6 @@
namespace Slang
{
-static bool _areAllZero(const UIntSet::Element* elems, Index count)
-{
- for (Index i = 0; count; ++i)
- {
- if (elems[i])
- {
- return false;
- }
- }
- return true;
-}
-
UIntSet& UIntSet::operator=(UIntSet&& other)
{
m_buffer = _Move(other.m_buffer);
@@ -49,14 +37,8 @@ void UIntSet::setAll()
void UIntSet::resize(UInt size)
{
- const Index oldCount = m_buffer.getCount();
const Index newCount = Index((size + kElementMask) >> kElementShift);
- m_buffer.setCount(newCount);
-
- if (newCount > oldCount)
- {
- ::memset(m_buffer.getBuffer() + oldCount, 0, (newCount - oldCount) * sizeof(Element));
- }
+ resizeBackingBufferDirectly(newCount);
}
void UIntSet::clear()
@@ -66,17 +48,7 @@ void UIntSet::clear()
bool UIntSet::isEmpty() const
{
- const Element*const src = m_buffer.getBuffer();
- const Index count = m_buffer.getCount();
-
- for (Index i = 0; i < count; ++i)
- {
- if (src[i])
- {
- return false;
- }
- }
- return true;
+ return _areAllZero(m_buffer.getBuffer(), m_buffer.getCount());
}
void UIntSet::clearAndDeallocate()
@@ -106,7 +78,7 @@ bool UIntSet::operator==(const UIntSet& set) const
const Index minCount = Math::Min(aCount, bCount);
- return ::memcmp(aElems, bElems, minCount) == 0 &&
+ return ::memcmp(aElems, bElems, minCount*sizeof(Element)) == 0 &&
_areAllZero(aElems + minCount, aCount - minCount) &&
_areAllZero(bElems + minCount, bCount - minCount);
}
@@ -123,6 +95,15 @@ void UIntSet::intersectWith(const UIntSet& set)
}
}
+void UIntSet::subtractWith(const UIntSet& set)
+{
+ const Index minCount = Math::Min(this->m_buffer.getCount(), set.m_buffer.getCount());
+ for (Index i = 0; i < minCount; i++)
+ {
+ this->m_buffer[i] = this->m_buffer[i] & (~set.m_buffer[i]);
+ }
+}
+
/* static */void UIntSet::calcUnion(UIntSet& outRs, const UIntSet& set1, const UIntSet& set2)
{
outRs.m_buffer.setCount(Math::Max(set1.m_buffer.getCount(), set2.m_buffer.getCount()));
@@ -162,5 +143,24 @@ void UIntSet::intersectWith(const UIntSet& set)
return false;
}
+Index UIntSet::countElements() const
+{
+ // TODO: This can be made faster using SIMD intrinsics to count set bits.
+ uint64_t tmp;
+ constexpr Index loopSize = ((sizeof(Element) / sizeof(tmp)) != 0) ? sizeof(Element) / sizeof(tmp) : 1;
+ Index count = 0;
+ for (auto index = 0; index < this->m_buffer.getCount(); index++)
+ {
+ for (auto i = 0; i < loopSize; i++)
+ {
+ tmp = m_buffer[index] >> (sizeof(tmp) * i);
+ tmp = tmp - ((tmp >> 1) & 0x5555555555555555);
+ tmp = (tmp & 0x3333333333333333) + ((tmp >> 2) & 0x3333333333333333);
+ count += ((tmp + (tmp >> 4) & 0xF0F0F0F0F0F0F0F) * 0x101010101010101) >> 56;
+ }
+ }
+ return count;
+}
+
}
diff --git a/source/core/slang-uint-set.h b/source/core/slang-uint-set.h
index 0f2165bab..22ca457b0 100644
--- a/source/core/slang-uint-set.h
+++ b/source/core/slang-uint-set.h
@@ -6,31 +6,83 @@
#include "slang-common.h"
#include "slang-hash.h"
+#if defined(_MSC_VER)
+#include <intrin.h>
+#endif
#include <memory.h>
namespace Slang
{
+template<typename T>
+constexpr static Index computeElementShift()
+{
+ Index currentShift = 0;
+ Index currentShiftValue = 1;
+
+ while (currentShiftValue != sizeof(T) * 8)
+ {
+ currentShift++;
+ currentShiftValue *= 2;
+ }
+
+ return currentShift;
+}
+
+static inline Index bitscanForward(uint64_t in)
+{
+#if defined(_MSC_VER)
+
+#ifdef _WIN64
+ uint64_t out = 0;
+ _BitScanForward64((unsigned long*)&out, in);
+ return Index(out);
+#else
+ constexpr uint32_t bitsInType = sizeof(uint32_t) * 8;
+ uint32_t out;
+ // check for 0s in 0bit->31bit. If all 0's, check for 0s in 32bit->63bit
+ _BitScanForward((unsigned long*)&out, *(((uint32_t*)&in) + 1));
+ if (out != bitsInType)
+ return Index(out);
+ _BitScanForward((unsigned long*)&out, *(((uint32_t*)&in)));
+ return Index(out + bitsInType);
+#endif// #ifdef _WIN64
+
+#else
+ return Index(__builtin_ctzll(in));
+#endif// #if defined(_MSC_VER)
+}
+
/* Hold a set of UInt values. Implementation works by storing as a bit per value */
+/// UIntSet is essentially a Element[], where each Element is `b` bits big.
+/// Each index has `b` number of integers. If the bit is 1, we have an element there.
+/// Value of each element is equal to the binary offset from Element[0], bit 0.
class UIntSet
{
public:
typedef UIntSet ThisType;
- typedef uint32_t Element; ///< Type that holds the bits to say if value is present
+ typedef uint64_t Element; ///< Type that holds the bits to say if value is present
+ constexpr static Index kElementSize = sizeof(Element) * 8; ///< The number of bits in an element. This also determines how many values a element can hold.
+ constexpr static Index kElementMask = kElementSize - 1; ///< Mask to get shift from an index
+ constexpr static Index kElementShift = computeElementShift<Element>(); ///< How many bits to shift to get Element index from an index. 5 for 2^5=32 elements in a uint32_t. 6 for 2^6=64 in a uint64_t.
+
UIntSet() {}
UIntSet(const UIntSet& other) { m_buffer = other.m_buffer; }
UIntSet(UIntSet && other) { *this = (_Move(other)); }
UIntSet(UInt maxVal) { resizeAndClear(maxVal); }
+ UIntSet(List<UIntSet::Element> buffer) { m_buffer = buffer; }
UIntSet& operator=(UIntSet&& other);
UIntSet& operator=(const UIntSet& other);
HashCode getHashCode() const;
- /// Return the count of all bits directly represented
+ /// Return the count of all bits directly represented
Int getCount() const { return Int(m_buffer.getCount()) * kElementSize; }
+ List<Element>& getBuffer() { return m_buffer; }
+
/// Resize such that val can be stored and clear contents
void resizeAndClear(UInt val);
/// Set all of the values up to count, as set
@@ -38,6 +90,7 @@ public:
/// Resize (but maintain contents) up to bit size.
/// NOTE! That since storage is in Element blocks, it may mean some values after size are set (up to the Element boundary)
void resize(UInt size);
+ void resizeBackingBufferDirectly(Index size);
/// Clear all of the contents (by clearing the bits)
void clear();
@@ -47,6 +100,8 @@ public:
/// Add a value
inline void add(UInt val);
+ inline void add(const UIntSet& val);
+
/// Remove a value
inline void remove(UInt val);
/// Returns true if the value is present
@@ -59,10 +114,12 @@ public:
/// !=
bool operator!=(const UIntSet& set) const { return !(*this == set); }
- /// Store the union between this and set in this
+ /// Store the union between this and set
void unionWith(const UIntSet& set);
- /// Store the intersection between this and set in this
+ /// Store the intersection between this and set
void intersectWith(const UIntSet& set);
+ /// Store the subtraction between this and set
+ void subtractWith(const UIntSet& set);
///
bool isEmpty() const;
@@ -70,6 +127,10 @@ public:
/// Swap this with rhs
void swapWith(ThisType& rhs) { m_buffer.swapWith(rhs.m_buffer); }
+ template<typename T>
+ List<T> getElements() const;
+ Index countElements() const;
+
/// Store the union of set1 and set2 in outRs
static void calcUnion(UIntSet& outRs, const UIntSet& set1, const UIntSet& set2);
/// Store the intersection of set1 and set2 in outRs
@@ -80,16 +141,98 @@ public:
/// Returns true if set1 and set2 have a same value set (ie there is an intersection)
static bool hasIntersection(const UIntSet& set1, const UIntSet& set2);
-private:
- enum
+ struct Iterator
{
- kElementShift = 5, ///< How many bits to shift to get Element index from an index
- kElementSize = sizeof(Element) * 8, ///< The number of bits in an element
- kElementMask = kElementSize - 1, ///< Mask to get shift from an index
+ friend class UIntSet;
+ private:
+ const List<Element>* context;
+ Index block = 0;
+ Element processedElement = 0;
+ uint64_t LSB = 0;
+
+ void clearLSB()
+ {
+ LSB = bitscanForward(processedElement);
+ processedElement &= processedElement - 1;
+ }
+ public:
+ Iterator(const List<Element>* inContext)
+ {
+ context = inContext;
+ }
+
+ Element operator*()
+ {
+ return Element(LSB + (kElementSize * block));
+ }
+
+ Iterator& operator++()
+ {
+ while (processedElement == 0)
+ {
+ block++;
+ if (block >= context->getCount())
+ {
+ return *this;
+ }
+ processedElement = (*context)[block];
+ }
+ clearLSB();
+ return *this;
+ }
+ Iterator& operator++(int)
+ {
+ return ++(*this);
+ }
+ bool operator==(const Iterator& other) const
+ {
+ return other.block == this->block
+ && other.processedElement == this->processedElement;
+ }
+ bool operator!=(const Iterator& other) const
+ {
+ return !(other == *this);
+ }
};
+ Iterator begin() const
+ {
+ Iterator tmp(&m_buffer);
+ if (m_buffer.getCount() == 0)
+ return tmp;
+
+ tmp.processedElement = m_buffer[0];
+ if (tmp.processedElement == 0)
+ tmp++;
+
+ tmp.clearLSB();
- // Make sure they are correct for the Element type
- SLANG_COMPILE_TIME_ASSERT((1 << kElementShift) == kElementSize);
+ return tmp;
+ }
+ Iterator end() const
+ {
+ Iterator tmp(&m_buffer);
+ tmp.block = m_buffer.getCount();
+ tmp.processedElement = 0;
+ return tmp;
+ }
+
+ bool areAllZero()
+ {
+ return _areAllZero(m_buffer.getBuffer(), m_buffer.getCount());
+ }
+
+protected:
+ static bool _areAllZero(const UIntSet::Element* elems, Index count)
+ {
+ for (Index i = 0; i < count; ++i)
+ {
+ if (elems[i])
+ {
+ return false;
+ }
+ }
+ return true;
+ }
List<Element> m_buffer;
};
@@ -132,6 +275,18 @@ inline bool UIntSet::contains(const UIntSet& set) const
}
// --------------------------------------------------------------------------
+
+inline void UIntSet::resizeBackingBufferDirectly(Index newCount)
+{
+ const Index oldCount = m_buffer.getCount();
+ m_buffer.setCount(newCount);
+
+ if (newCount > oldCount)
+ {
+ ::memset(m_buffer.getBuffer() + oldCount, 0, (newCount - oldCount) * sizeof(Element));
+ }
+}
+
inline void UIntSet::add(UInt val)
{
const Index idx = Index(val >> kElementShift);
@@ -142,6 +297,38 @@ inline void UIntSet::add(UInt val)
m_buffer[idx] |= Element(1) << (val & kElementMask);
}
+inline void UIntSet::add(const UIntSet& other)
+{
+ auto otherCount = other.m_buffer.getCount();
+ if (this->m_buffer.getCount() < otherCount)
+ resizeBackingBufferDirectly(otherCount);
+
+ for (auto i = 0; i < otherCount; i++)
+ m_buffer[i] |= other.m_buffer[i];
}
+template<typename T>
+List<T> UIntSet::getElements() const
+{
+ auto count = m_buffer.getCount();
+ if (count == 0)
+ return {};
+
+ // Specific path for uint64_t. If using SIMD we should not use this path due to larger data types.
+
+ List<T> elements;
+ elements.reserve(count);
+ for (Index block = 0; block < count; block++)
+ {
+ Element n = m_buffer[block];
+ while (n != 0)
+ {
+ elements.add(T(bitscanForward((uint64_t)n) + (kElementSize * block)));
+ n &= n - 1;
+ }
+ }
+ return elements;
+}
+
+}
#endif