diff options
| author | Yong He <yonghe@outlook.com> | 2024-10-15 18:54:16 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-10-15 18:54:16 -0700 |
| commit | c97166aed29e0a224d49cec0b12503d1a10b52e0 (patch) | |
| tree | 1894ff8a3b608d66f55f5f2bd47640e679e59e78 /tests | |
| parent | 99a242eca78149a61c0521d319e96ededec7168d (diff) | |
Fix type checking on generic extensions. (#5316)
Add fcpw library to test suite.
Diffstat (limited to 'tests')
| -rw-r--r-- | tests/bugs/gh-5140.slang | 33 | ||||
| -rw-r--r-- | tests/fcpw/LICENSE | 21 | ||||
| -rw-r--r-- | tests/fcpw/aggregate.slang | 147 | ||||
| -rw-r--r-- | tests/fcpw/bounding-volumes.slang | 300 | ||||
| -rw-r--r-- | tests/fcpw/bvh-node.slang | 193 | ||||
| -rw-r--r-- | tests/fcpw/bvh-refit.cs.slang | 45 | ||||
| -rw-r--r-- | tests/fcpw/bvh-traversal.cs.slang | 131 | ||||
| -rw-r--r-- | tests/fcpw/bvh.slang | 675 | ||||
| -rw-r--r-- | tests/fcpw/fcpw.slang | 2 | ||||
| -rw-r--r-- | tests/fcpw/geometry.slang | 708 | ||||
| -rw-r--r-- | tests/fcpw/interaction.slang | 21 | ||||
| -rw-r--r-- | tests/fcpw/math-constants.slang | 7 | ||||
| -rw-r--r-- | tests/fcpw/ray.slang | 18 | ||||
| -rw-r--r-- | tests/fcpw/transform.slang | 56 |
14 files changed, 2357 insertions, 0 deletions
diff --git a/tests/bugs/gh-5140.slang b/tests/bugs/gh-5140.slang new file mode 100644 index 000000000..23d9b3a23 --- /dev/null +++ b/tests/bugs/gh-5140.slang @@ -0,0 +1,33 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-vk -output-using-type +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK):-d3d11 -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer + +public interface A<T: IFloat> +{ +} + +public extension<T1: IFloat, a1 : A<T1>> a1 +{ + void foo() { + outputBuffer[0] = 1.0; + } +} + +RWStructuredBuffer<float> outputBuffer; +struct S : A<float> { +} + +void helper<T: IFloat, a : A<T>>(a a2) +{ + a2.foo(); +} + +// CHECK: 1 + +[numthreads(1,1,1)] +void computeMain() +{ + S a; + helper(a); +}
\ No newline at end of file diff --git a/tests/fcpw/LICENSE b/tests/fcpw/LICENSE new file mode 100644 index 000000000..8db71e851 --- /dev/null +++ b/tests/fcpw/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2020 Rohan Sawhney + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/tests/fcpw/aggregate.slang b/tests/fcpw/aggregate.slang new file mode 100644 index 000000000..f83fa126a --- /dev/null +++ b/tests/fcpw/aggregate.slang @@ -0,0 +1,147 @@ +implementing fcpw; +__include ray; +__include interaction; +__include bounding_volumes; +__include transform; + +public interface IBranchTraversalWeight +{ + // computes the traversal weight for a given squared distance + float compute(float r2); +}; + +public struct ConstantBranchTraversalWeight : IBranchTraversalWeight +{ + // computes the traversal weight for a given squared distance + public float compute(float r2) + { + return 1.0; + } +}; + +public interface IAggregate +{ + // updates the bounding volume of an aggregate node + [mutating] + void refit(uint nodeIndex); + + // intersects aggregate geometry with ray + bool intersect(inout Ray r, bool checkForOcclusion, inout Interaction i); + + // intersects aggregate geometry with sphere + bool intersect<T : IBranchTraversalWeight>(BoundingSphere s, float3 randNums, + T branchTraversalWeight, + inout Interaction i); + + // finds closest point on aggregate geometry from sphere center + bool findClosestPoint(inout BoundingSphere s, inout Interaction i, + bool recordNormal = false); + + // finds closest silhouette point on aggregate geometry from sphere center + bool findClosestSilhouettePoint(inout BoundingSphere s, bool flipNormalOrientation, + float squaredMinRadius, float precision, + inout Interaction i); +}; + +public struct TransformedAggregate<A : IAggregate> : IAggregate +{ + public A aggregate; + public float3x4 t; + public float3x4 tInv; + + // updates the bounding volume of an aggregate node + // NOTE: refitting of transformed aggregates is currently quite inefficient, since the + // shared aggregate is refit every time this function is called + [mutating] + public void refit(uint nodeIndex) + { + aggregate.refit(nodeIndex); + } + + // intersects aggregate geometry with ray + public bool intersect(inout Ray r, bool checkForOcclusion, inout Interaction i) + { + // apply inverse transform to ray + Ray rInv = transformRay(tInv, r); + + // intersect + bool didIntersect = aggregate.intersect(rInv, checkForOcclusion, i); + + // apply transform to ray and interaction + r.tMax = transformRay(t, rInv).tMax; + if (didIntersect) + { + transformInteraction(t, tInv, r.o, true, i); + return true; + } + + return false; + } + + // intersects aggregate geometry with sphere + public bool intersect<T : IBranchTraversalWeight>(BoundingSphere s, float3 randNums, + T branchTraversalWeight, + inout Interaction i) + { + // apply inverse transform to sphere + BoundingSphere sInv = transformSphere(tInv, s); + + // intersect + bool didIntersect = aggregate.intersect(sInv, randNums, branchTraversalWeight, i); + + // apply transform to interaction + if (didIntersect) + { + transformInteraction(t, tInv, s.c, false, i); + return true; + } + + return false; + } + + // finds closest point on aggregate geometry from sphere center + public bool findClosestPoint(inout BoundingSphere s, inout Interaction i, + bool recordNormal = false) + { + // apply inverse transform to sphere + BoundingSphere sInv = transformSphere(tInv, s); + + // find closest point + bool didFindClosestPoint = aggregate.findClosestPoint(sInv, i, recordNormal); + + // apply transform to sphere and interaction + s.r2 = transformSphere(t, sInv).r2; + if (didFindClosestPoint) + { + transformInteraction(t, tInv, s.c, true, i); + return true; + } + + return false; + } + + // finds closest silhouette point on aggregate geometry from sphere center + public bool findClosestSilhouettePoint(inout BoundingSphere s, bool flipNormalOrientation, + float squaredMinRadius, float precision, + inout Interaction i) + { + // apply inverse transform to sphere + BoundingSphere sInv = transformSphere(tInv, s); + BoundingSphere sMin = BoundingSphere(s.c, squaredMinRadius); + BoundingSphere sMinInv = transformSphere(tInv, sMin); + + // find closest silhouette point + bool didFindClosestSilhouettePoint = aggregate.findClosestSilhouettePoint( + sInv, flipNormalOrientation, sMinInv.r2, precision, i); + + // apply transform to sphere and interaction + s.r2 = transformSphere(t, sInv).r2; + if (didFindClosestSilhouettePoint) + { + transformInteraction(t, tInv, s.c, true, i); + return true; + } + + return false; + } +}; diff --git a/tests/fcpw/bounding-volumes.slang b/tests/fcpw/bounding-volumes.slang new file mode 100644 index 000000000..39094f228 --- /dev/null +++ b/tests/fcpw/bounding-volumes.slang @@ -0,0 +1,300 @@ +implementing fcpw; +__include ray; +__include math_constants; + +public struct BoundingSphere +{ + public float3 c; // sphere center + public float r2; // sphere squared radius + + // constructor + public __init(float3 c_, float r2_) + { + c = c_; + r2 = r2_; + } +}; + +public struct BoundingBox +{ + public float3 pMin; // aabb min position + public float3 pMax; // aabb max position + + // constructor + public __init(float3 pMin_, float3 pMax_) + { + pMin = pMin_; + pMax = pMax_; + } + + // checks whether box is valid + public bool isValid() + { + return all(pMin <= pMax); + } + + // returns box centroid + public float3 getCentroid() + { + return 0.5 * (pMin + pMax); + } + + // checks for overlap with sphere + public bool overlap(BoundingSphere s, out float d2Min, out float d2Max) + { + float3 u = pMin - s.c; + float3 v = s.c - pMax; + float3 a = max(max(u, v), float3(0.0, 0.0, 0.0)); + float3 b = min(u, v); + d2Min = dot(a, a); + d2Max = dot(b, b); + + return d2Min <= s.r2; + } + + // checks for overlap with sphere + public bool overlap(BoundingSphere s, out float d2Min) + { + float3 u = pMin - s.c; + float3 v = s.c - pMax; + float3 a = max(max(u, v), float3(0.0, 0.0, 0.0)); + d2Min = dot(a, a); + + return d2Min <= s.r2; + } + + // intersects box with ray + public bool intersect(Ray r, out float tMin, out float tMax) + { + // slab test for ray box intersection + // source: http://www.jcgt.org/published/0007/03/04/paper-lowres.pdf + float3 t0 = (pMin - r.o) * r.dInv; + float3 t1 = (pMax - r.o) * r.dInv; + float3 tNear = min(t0, t1); + float3 tFar = max(t0, t1); + + float tNearMax = max(0.0, max(tNear.x, max(tNear.y, tNear.z))); + float tFarMin = min(r.tMax, min(tFar.x, min(tFar.y, tFar.z))); + if (tNearMax > tFarMin) + { + tMin = FLT_MAX; + tMax = FLT_MAX; + return false; + } + + tMin = tNearMax; + tMax = tFarMin; + return true; + } +}; + +public BoundingBox mergeBoundingBoxes(BoundingBox boxA, BoundingBox boxB) +{ + BoundingBox box; + box.pMin = min(boxA.pMin, boxB.pMin); + box.pMax = max(boxA.pMax, boxB.pMax); + + return box; +} + +public bool inRange(float val, float low, float high) +{ + return val >= low && val <= high; +} + +public void computeOrthonormalBasis(float3 n, out float3 b1, out float3 b2) +{ + // source: https://graphics.pixar.com/library/OrthonormalB/paper.pdf + float sign = n.z >= 0.0 ? 1.0 : -1.0; + float a = -1.0 / (sign + n.z); + float b = n.x * n.y * a; + + b1 = float3(1.0 + sign * n.x * n.x * a, sign * b, -sign * n.x); + b2 = float3(b, sign + n.y * n.y * a, -n.y); +} + +public float projectToPlane(float3 n, float3 e) +{ + // compute orthonormal basis + float3 b1, b2; + computeOrthonormalBasis(n, b1, b2); + + // compute maximal projection radius + float r1 = dot(e, abs(b1)); + float r2 = dot(e, abs(b2)); + return sqrt(r1 * r1 + r2 * r2); +} + +public struct BoundingCone +{ + public float3 axis; // cone axis + public float halfAngle; // cone half angle + public float radius; // cone radius + + // constructors + public __init() + { + axis = float3(0.0, 0.0, 0.0); + halfAngle = M_PI; + radius = 0.0; + } + public __init(float3 axis_, float halfAngle_, float radius_) + { + axis = axis_; + halfAngle = halfAngle_; + radius = radius_; + } + + // checks whether cone is valid + public bool isValid() + { + return halfAngle >= 0.0; + } + + // check for overlap between this cone and the "view" cone defined by the given + // point o and bounding box b; the two cones overlap when there exist two vectors, + // one in each cone, that are orthogonal to each other. + public bool overlap(float3 o, BoundingBox b, float distToBox, + out float minAngleRange, out float maxAngleRange) + { + // initialize angle bounds + minAngleRange = 0.0f; + maxAngleRange = M_PI_2; + + // there's overlap if this cone's halfAngle is greater than 90 degrees, or + // if the box contains the view cone origin (since the view cone is invalid) + if (halfAngle >= M_PI_2 || distToBox < FLT_EPSILON) + { + return true; + } + + // compute the view cone axis + float3 c = b.getCentroid(); + float3 viewConeAxis = c - o; + float l = length(viewConeAxis); + viewConeAxis /= l; + + // check for overlap between the view cone axis and this cone + float dAxisAngle = acos(max(-1.0, min(1.0, dot(axis, viewConeAxis)))); // [0, 180] + if (inRange(M_PI_2, dAxisAngle - halfAngle, dAxisAngle + halfAngle)) + { + return true; + } + + // check if the view cone origin lies outside this cone's bounding sphere; + // if it does, compute the view cone halfAngle and check for overlap + if (l > radius) + { + float viewConeHalfAngle = asin(radius / l); + float halfAngleSum = halfAngle + viewConeHalfAngle; + minAngleRange = dAxisAngle - halfAngleSum; + maxAngleRange = dAxisAngle + halfAngleSum; + return halfAngleSum >= M_PI_2 ? true : inRange(M_PI_2, minAngleRange, maxAngleRange); + } + + // the view cone origin lies inside the box's bounding sphere, so check if + // the plane defined by the view cone axis intersects the box; if it does, then + // there's overlap since the view cone has a halfAngle greater than 90 degrees + float3 e = b.pMax - c; + float d = dot(e, abs(viewConeAxis)); // max projection length onto axis + float s = l - d; + if (s <= 0.0) + { + return true; + } + + // compute the view cone halfAngle by projecting the max extents of the box + // onto the plane, and check for overlap + d = projectToPlane(viewConeAxis, e); + float viewConeHalfAngle = atan2(d, s); + float halfAngleSum = halfAngle + viewConeHalfAngle; + minAngleRange = dAxisAngle - halfAngleSum; + maxAngleRange = dAxisAngle + halfAngleSum; + return halfAngleSum >= M_PI_2 ? true : inRange(M_PI_2, minAngleRange, maxAngleRange); + } +}; + +public float3 rotate(float3 u, float3 v, float theta) +{ + float cosTheta = cos(theta); + float sinTheta = sin(theta); + float3 w = length(cross(u, v)); + float3 oneMinusCosThetaW = (1.0 - cosTheta) * w; + + float3x3 R; + R[0][0] = cosTheta + oneMinusCosThetaW[0] * w[0]; + R[0][1] = oneMinusCosThetaW[1] * w[0] - sinTheta * w[2]; + R[0][2] = oneMinusCosThetaW[2] * w[0] + sinTheta * w[1]; + R[1][0] = oneMinusCosThetaW[0] * w[1] + sinTheta * w[2]; + R[1][1] = cosTheta + oneMinusCosThetaW[1] * w[1]; + R[1][2] = oneMinusCosThetaW[2] * w[1] - sinTheta * w[0]; + R[2][0] = oneMinusCosThetaW[0] * w[2] - sinTheta * w[1]; + R[2][1] = oneMinusCosThetaW[1] * w[2] + sinTheta * w[0]; + R[2][2] = cosTheta + oneMinusCosThetaW[2] * w[2]; + + return mul(R, u); +} + +public BoundingCone mergeBoundingCones(BoundingCone coneA, BoundingCone coneB, + float3 originA, float3 originB, + float3 newOrigin) +{ + BoundingCone cone; + if (coneA.isValid() && coneB.isValid()) + { + float3 axisA = coneA.axis; + float3 axisB = coneB.axis; + float halfAngleA = coneA.halfAngle; + float halfAngleB = coneB.halfAngle; + float3 dOriginA = newOrigin - originA; + float3 dOriginB = newOrigin - originB; + cone.radius = sqrt(max(coneA.radius * coneA.radius + dot(dOriginA, dOriginA), + coneB.radius * coneB.radius + dot(dOriginB, dOriginB))); + + if (halfAngleB > halfAngleA) + { + float3 tmpAxis = axisA; + axisA = axisB; + axisB = tmpAxis; + + float tmpHalfAngle = halfAngleA; + halfAngleA = halfAngleB; + halfAngleB = tmpHalfAngle; + } + + float theta = acos(max(-1.0, min(1.0, dot(axisA, axisB)))); + if (min(theta + halfAngleB, M_PI) <= halfAngleA) + { + // right cone is completely inside left cone + cone.axis = axisA; + cone.halfAngle = halfAngleA; + return cone; + } + + // merge cones by first computing the spread angle of the cone to cover both cones + float oTheta = (halfAngleA + theta + halfAngleB) / 2.0; + if (oTheta >= M_PI) + { + cone.axis = axisA; + return cone; + } + + float rTheta = oTheta - halfAngleA; + cone.axis = rotate(axisA, axisB, rTheta); + cone.halfAngle = oTheta; + } + else if (coneA.isValid()) + { + cone = coneA; + } + else if (coneB.isValid()) + { + cone = coneB; + } + else + { + cone.halfAngle = -M_PI; + } + + return cone; +} diff --git a/tests/fcpw/bvh-node.slang b/tests/fcpw/bvh-node.slang new file mode 100644 index 000000000..2805ae5ee --- /dev/null +++ b/tests/fcpw/bvh-node.slang @@ -0,0 +1,193 @@ +implementing fcpw; +__include bounding_volumes; + +public interface IBvhNode +{ + // returns the bounding box of the node + BoundingBox getBoundingBox(); + + // sets the bounding box of the node + [mutating] + void setBoundingBox(BoundingBox box_); + + // checks if the node has a bounding cone + bool hasBoundingCone(); + + // returns the bounding cone of the node + BoundingCone getBoundingCone(); + + // sets the bounding cone of the node + [mutating] + void setBoundingCone(BoundingCone cone_); + + // checks if the node is a leaf node + bool isLeaf(); + + // returns the offset to the right child of the interior node + uint getRightChildOffset(); + + // returns the number of primitives in the node + uint getNumPrimitives(); + + // returns the offset to the first primitive of the leaf node + uint getPrimitiveOffset(); + + // returns the number of silhouettes in the node + uint getNumSilhouettes(); + + // returns the offset to the first silhouette of the leaf node + uint getSilhouetteOffset(); +}; + +public struct BvhNode : IBvhNode +{ + public BoundingBox box; + public uint nPrimitives; + public uint offset; + + // returns the bounding box of the node + public BoundingBox getBoundingBox() + { + return box; + } + + // sets the bounding box of the node + [mutating] + public void setBoundingBox(BoundingBox box_) + { + box = box_; + } + + // checks if the node has a bounding cone + public bool hasBoundingCone() + { + return false; + } + + // returns the bounding cone of the node + public BoundingCone getBoundingCone() + { + return BoundingCone(); + } + + // sets the bounding cone of the node + [mutating] + public void setBoundingCone(BoundingCone cone_) + { + // do nothing + } + + // checks if the node is a leaf node + public bool isLeaf() + { + return nPrimitives > 0; + } + + // returns the offset to the right child of the interior node + public uint getRightChildOffset() + { + return offset; + } + + // returns the number of primitives in the node + public uint getNumPrimitives() + { + return nPrimitives; + } + + // returns the offset to the first primitive of the leaf node + public uint getPrimitiveOffset() + { + return offset; + } + + // returns the number of silhouettes in the node + public uint getNumSilhouettes() + { + return 0; + } + + // returns the offset to the first silhouette of the leaf node + public uint getSilhouetteOffset() + { + return 0; + } +}; + +public struct SnchNode : IBvhNode +{ + public BoundingBox box; + public BoundingCone cone; + public uint nPrimitives; + public uint offset; + public uint nSilhouettes; + public uint silhouetteOffset; + + // returns the bounding box of the node + public BoundingBox getBoundingBox() + { + return box; + } + + // sets the bounding box of the node + [mutating] + public void setBoundingBox(BoundingBox box_) + { + box = box_; + } + + // checks if the node has a bounding cone + public bool hasBoundingCone() + { + return true; + } + + // returns the bounding cone of the node + public BoundingCone getBoundingCone() + { + return cone; + } + + // sets the bounding cone of the node + [mutating] + public void setBoundingCone(BoundingCone cone_) + { + cone = cone_; + } + + // checks if the node is a leaf node + public bool isLeaf() + { + return nPrimitives > 0; + } + + // returns the offset to the right child of the interior node + public uint getRightChildOffset() + { + return offset; + } + + // returns the number of primitives in the node + public uint getNumPrimitives() + { + return nPrimitives; + } + + // returns the offset to the first primitive of the leaf node + public uint getPrimitiveOffset() + { + return offset; + } + + // returns the number of silhouettes in the node + public uint getNumSilhouettes() + { + return nSilhouettes; + } + + // returns the offset to the first silhouette of the leaf node + public uint getSilhouetteOffset() + { + return silhouetteOffset; + } +}; diff --git a/tests/fcpw/bvh-refit.cs.slang b/tests/fcpw/bvh-refit.cs.slang new file mode 100644 index 000000000..2ccd435d2 --- /dev/null +++ b/tests/fcpw/bvh-refit.cs.slang @@ -0,0 +1,45 @@ +import fcpw; + +#define UNDEFINED_BVH_TYPE 0 +#define LINE_SEGMENT_BVH 1 +#define TRIANGLE_BVH 2 +#define LINE_SEGMENT_SNCH 3 +#define TRIANGLE_SNCH 4 + +#ifndef _BVH_TYPE +#define _BVH_TYPE UNDEFINED_BVH_TYPE +#endif + +#if _BVH_TYPE == LINE_SEGMENT_BVH +uniform ParameterBlock<Bvh<BvhNode, LineSegment, NoSilhouette>> gBvh; + +#elif _BVH_TYPE == TRIANGLE_BVH +uniform ParameterBlock<Bvh<BvhNode, Triangle, NoSilhouette>> gBvh; + +#elif _BVH_TYPE == LINE_SEGMENT_SNCH +uniform ParameterBlock<Bvh<SnchNode, LineSegment, Vertex>> gBvh; + +#elif _BVH_TYPE == TRIANGLE_SNCH +uniform ParameterBlock<Bvh<SnchNode, Triangle, Edge>> gBvh; + +#else +// Compile time error +#error _BVH_TYPE is not set to a supported type +#endif + +[shader("compute")] +[numthreads(256, 1, 1)] +void refit(uint3 threadId: SV_DispatchThreadID, + uniform StructuredBuffer<uint> nodeIndices, + uniform uint firstNodeOffset, + uniform uint nodeCount) +{ + uint index = threadId.x; + if (index >= nodeCount) + { + return; + } + + uint nodeIndex = nodeIndices[firstNodeOffset + index]; + gBvh.refit(nodeIndex); +} diff --git a/tests/fcpw/bvh-traversal.cs.slang b/tests/fcpw/bvh-traversal.cs.slang new file mode 100644 index 000000000..e1be3118a --- /dev/null +++ b/tests/fcpw/bvh-traversal.cs.slang @@ -0,0 +1,131 @@ +import fcpw; + +#define UNDEFINED_BVH_TYPE 0 +#define LINE_SEGMENT_BVH 1 +#define TRIANGLE_BVH 2 +#define LINE_SEGMENT_SNCH 3 +#define TRIANGLE_SNCH 4 + +#ifndef _BVH_TYPE +#define _BVH_TYPE UNDEFINED_BVH_TYPE +#endif + +#if _BVH_TYPE == LINE_SEGMENT_BVH +uniform ParameterBlock<Bvh<BvhNode, LineSegment, NoSilhouette>> gBvh; +#define _BVH_HAS_SILHOUETTE_DATA 0 + +#elif _BVH_TYPE == TRIANGLE_BVH +uniform ParameterBlock<Bvh<BvhNode, Triangle, NoSilhouette>> gBvh; +#define _BVH_HAS_SILHOUETTE_DATA 0 + +#elif _BVH_TYPE == LINE_SEGMENT_SNCH +uniform ParameterBlock<Bvh<SnchNode, LineSegment, Vertex>> gBvh; +#define _BVH_HAS_SILHOUETTE_DATA 1 + +#elif _BVH_TYPE == TRIANGLE_SNCH +uniform ParameterBlock<Bvh<SnchNode, Triangle, Edge>> gBvh; +#define _BVH_HAS_SILHOUETTE_DATA 1 + +#else +// Compile time error +#error _BVH_TYPE is not set to a supported type +#endif + +[shader("compute")] +[numthreads(256, 1, 1)] +void rayIntersection(uint3 threadId: SV_DispatchThreadID, + uniform StructuredBuffer<Ray> rays, + uniform bool checkForOcclusion, + uniform RWStructuredBuffer<Interaction> interactions, + uniform uint nQueries) +{ + uint index = threadId.x; + if (index >= nQueries) + { + return; + } + + Ray r = rays[index]; + Interaction i; + bool didIntersect = gBvh.intersect(r, checkForOcclusion, i); + if (didIntersect) + { + interactions[index] = i; + } +} + +[shader("compute")] +[numthreads(256, 1, 1)] +void sphereIntersection(uint3 threadId: SV_DispatchThreadID, + uniform StructuredBuffer<BoundingSphere> boundingSpheres, + uniform StructuredBuffer<float3> randNums, + uniform RWStructuredBuffer<Interaction> interactions, + uniform uint nQueries) +{ + uint index = threadId.x; + if (index >= nQueries) + { + return; + } + + BoundingSphere s = boundingSpheres[index]; + float3 randNum = randNums[index]; + ConstantBranchTraversalWeight branchTraversalWeight; + Interaction i; + bool didIntersect = gBvh.intersect<ConstantBranchTraversalWeight>(s, randNum, branchTraversalWeight, i); + if (didIntersect) + { + interactions[index] = i; + } +} + +[shader("compute")] +[numthreads(256, 1, 1)] +void closestPoint(uint3 threadId: SV_DispatchThreadID, + uniform StructuredBuffer<BoundingSphere> boundingSpheres, + uniform RWStructuredBuffer<Interaction> interactions, + uniform bool recordNormals, + uniform uint nQueries) +{ + uint index = threadId.x; + if (index >= nQueries) + { + return; + } + + BoundingSphere s = boundingSpheres[index]; + Interaction i; + bool found = gBvh.findClosestPoint(s, i, recordNormals); + if (found) + { + interactions[index] = i; + } +} + +[shader("compute")] +[numthreads(256, 1, 1)] +void closestSilhouettePoint(uint3 threadId: SV_DispatchThreadID, + uniform StructuredBuffer<BoundingSphere> boundingSpheres, + uniform StructuredBuffer<uint> flipNormalOrientation, + uniform float squaredMinRadius, + uniform float precision, + uniform RWStructuredBuffer<Interaction> interactions, + uniform uint nQueries) +{ + uint index = threadId.x; + if (index >= nQueries) + { + return; + } + + Interaction i; +#if _BVH_HAS_SILHOUETTE_DATA + BoundingSphere s = boundingSpheres[index]; + bool flipNormal = flipNormalOrientation[index] == 1 ? true : false; + bool found = gBvh.findClosestSilhouettePoint(s, flipNormal, squaredMinRadius, precision, i); + if (found) + { + interactions[index] = i; + } +#endif +} diff --git a/tests/fcpw/bvh.slang b/tests/fcpw/bvh.slang new file mode 100644 index 000000000..f110264c3 --- /dev/null +++ b/tests/fcpw/bvh.slang @@ -0,0 +1,675 @@ +implementing fcpw; +__include aggregate; +__include geometry; +__include bvh_node; + +public static const uint FCPW_BVH_MAX_DEPTH = 64; + +public struct TraversalStack +{ + public uint node; // node index + public float distance; // minimum distance (parametric, squared, ...) to this node + + // constructor + public __init() + { + node = 0; + distance = 0.0; + } +}; + +public struct Bvh<N : IBvhNode, P : IPrimitive, S : ISilhouette> : IAggregate +{ + public RWStructuredBuffer<N> nodes; + public StructuredBuffer<P> primitives; + public StructuredBuffer<S> silhouettes; + + // updates the bounding volume of an aggregate leaf node + [mutating] + internal void refitLeafNode(uint nodeIndex) + { + // update leaf node's bounding box + float3 pMin = float3(FLT_MAX, FLT_MAX, FLT_MAX); + float3 pMax = float3(-FLT_MAX, -FLT_MAX, -FLT_MAX); + N node = nodes[nodeIndex]; + uint nPrimitives = node.getNumPrimitives(); + + for (uint p = 0; p < nPrimitives; p++) + { + uint primitiveIndex = node.getPrimitiveOffset() + p; + BoundingBox primitiveBox = primitives[primitiveIndex].getBoundingBox(); + pMin = min(pMin, primitiveBox.pMin); + pMax = max(pMax, primitiveBox.pMax); + } + + node.setBoundingBox(BoundingBox(pMin, pMax)); + + if (node.hasBoundingCone()) + { + // update leaf node's bounding cone + float3 axis = float3(0.0, 0.0, 0.0); + float3 centroid = 0.5 * (pMin + pMax); + float halfAngle = 0.0; + float radius = 0.0; + bool anySilhouettes = false; + bool silhouettesHaveTwoAdjacentFaces = true; + uint nSilhouettes = node.getNumSilhouettes(); + + for (uint p = 0; p < nSilhouettes; p++) + { + uint silhouetteIndex = node.getSilhouetteOffset() + p; + S silhouette = silhouettes[silhouetteIndex]; + axis += silhouette.getNormal(0); + axis += silhouette.getNormal(1); + radius = max(radius, length(silhouette.getCentroid() - centroid)); + silhouettesHaveTwoAdjacentFaces = silhouettesHaveTwoAdjacentFaces && + silhouette.hasTwoAdjacentFaces(); + anySilhouettes = true; + } + + if (!anySilhouettes) + { + halfAngle = -M_PI; + } + else if (!silhouettesHaveTwoAdjacentFaces) + { + halfAngle = M_PI; + } + else + { + float axisNorm = length(axis); + if (axisNorm > FLT_EPSILON) + { + axis /= axisNorm; + + for (uint p = 0; p < nSilhouettes; p++) + { + uint silhouetteIndex = node.getSilhouetteOffset() + p; + for (uint k = 0; k < 2; k++) + { + float3 n = silhouettes[silhouetteIndex].getNormal(k); + float angle = acos(max(-1.0, min(1.0, dot(axis, n)))); + halfAngle = max(halfAngle, angle); + } + } + } + } + + node.setBoundingCone(BoundingCone(axis, halfAngle, radius)); + } + } + + // updates the bounding volume of an aggregate internal node + [mutating] + internal void refitInternalNode(uint nodeIndex) + { + // update internal node's bounding box + N node = nodes[nodeIndex]; + uint leftNodeIndex = nodeIndex + 1; + uint rightNodeIndex = nodeIndex + node.getRightChildOffset(); + N leftNode = nodes[leftNodeIndex]; + N rightNode = nodes[rightNodeIndex]; + + BoundingBox leftBox = leftNode.getBoundingBox(); + BoundingBox rightBox = rightNode.getBoundingBox(); + BoundingBox mergedBox = mergeBoundingBoxes(leftBox, rightBox); + node.setBoundingBox(mergedBox); + + if (node.hasBoundingCone()) + { + // update internal node's bounding cone + BoundingCone leftCone = leftNode.getBoundingCone(); + BoundingCone rightCone = rightNode.getBoundingCone(); + BoundingCone mergedCone = mergeBoundingCones(leftCone, rightCone, + leftBox.getCentroid(), + rightBox.getCentroid(), + mergedBox.getCentroid()); + node.setBoundingCone(mergedCone); + } + } + + // updates the bounding volume of an aggregate node + // NOTE: assumes node indices are provided in bottom-up order + [mutating] + public void refit(uint nodeIndex) + { + if (nodes[nodeIndex].isLeaf()) + { + refitLeafNode(nodeIndex); + } + else + { + refitInternalNode(nodeIndex); + } + } + + // intersects aggregate geometry with ray + public bool intersect(inout Ray r, bool checkForOcclusion, inout Interaction i) + { + TraversalStack traversalStack[FCPW_BVH_MAX_DEPTH]; + float4 distToChildNodes = float4(0.0, 0.0, 0.0, 0.0); + BoundingBox rootBox = nodes[0].getBoundingBox(); + bool didIntersect = false; + + if (rootBox.intersect(r, distToChildNodes[0], distToChildNodes[1])) + { + traversalStack[0].node = 0; + traversalStack[0].distance = distToChildNodes[0]; + int stackPtr = 0; + + while (stackPtr >= 0) + { + // pop off the next node to work on + uint currentNodeIndex = traversalStack[stackPtr].node; + float currentDist = traversalStack[stackPtr].distance; + stackPtr--; + + // if this node is further than the closest found intersection, continue + if (currentDist > r.tMax) + { + continue; + } + + N node = nodes[currentNodeIndex]; + if (node.isLeaf()) + { + // intersect primitives in leaf node + uint nPrimitives = node.getNumPrimitives(); + for (uint p = 0; p < nPrimitives; p++) + { + Interaction c; + uint primitiveIndex = node.getPrimitiveOffset() + p; + bool didIntersectPrimitive = primitives[primitiveIndex].intersect(r, checkForOcclusion, c); + + if (didIntersectPrimitive) + { + if (checkForOcclusion) + { + i.index = c.index; + return true; + } + + didIntersect = true; + r.tMax = min(r.tMax, c.d); + i = c; + } + } + } + else + { + // intersect child nodes + uint leftNodeIndex = currentNodeIndex + 1; + BoundingBox leftBox = nodes[leftNodeIndex].getBoundingBox(); + bool didIntersectLeft = leftBox.intersect(r, distToChildNodes[0], distToChildNodes[1]); + + uint rightNodeIndex = currentNodeIndex + node.getRightChildOffset(); + BoundingBox rightBox = nodes[rightNodeIndex].getBoundingBox(); + bool didIntersectRight = rightBox.intersect(r, distToChildNodes[2], distToChildNodes[3]); + + // which nodes did we intersect? + if (didIntersectLeft && didIntersectRight) + { + // assume that the left child is closer + uint closer = leftNodeIndex; + uint other = rightNodeIndex; + + // ... if the right child was actually closer, swap the relavent values + if (distToChildNodes[2] < distToChildNodes[0]) + { + float tmpDist = distToChildNodes[0]; + distToChildNodes[0] = distToChildNodes[2]; + distToChildNodes[2] = tmpDist; + + uint tmpNodeIndex = closer; + closer = other; + other = tmpNodeIndex; + } + + // it's possible that the nearest primitive is still in the other node, + // but we'll check the farther-away node later. + + // push the further node first, then the closer node + stackPtr++; + traversalStack[stackPtr].node = other; + traversalStack[stackPtr].distance = distToChildNodes[2]; + + stackPtr++; + traversalStack[stackPtr].node = closer; + traversalStack[stackPtr].distance = distToChildNodes[0]; + } + else if (didIntersectLeft) + { + stackPtr++; + traversalStack[stackPtr].node = leftNodeIndex; + traversalStack[stackPtr].distance = distToChildNodes[0]; + } + else if (didIntersectRight) + { + stackPtr++; + traversalStack[stackPtr].node = rightNodeIndex; + traversalStack[stackPtr].distance = distToChildNodes[2]; + } + } + } + } + + return didIntersect; + } + + // intersects aggregate geometry with sphere + public bool intersect<T : IBranchTraversalWeight>(BoundingSphere s, float3 randNums, + T branchTraversalWeight, + inout Interaction i) + { + float4 distToChildNodes = float4(0.0, 0.0, 0.0, 0.0); + BoundingBox rootBox = nodes[0].getBoundingBox(); + uint currentNodeIndex = 0; + uint selectedPrimitiveIndex = UINT_MAX; + bool didIntersect = false; + + if (rootBox.overlap(s, distToChildNodes[0], distToChildNodes[1])) + { + float maxDistToChildNode = distToChildNodes[1]; + float traversalPdf = 1.0; + float u = randNums[0]; + int stackPtr = 0; + + while (stackPtr >= 0) + { + // pop off the next node to work on + stackPtr--; + + N node = nodes[currentNodeIndex]; + if (node.isLeaf()) + { + // probabilistically select a primitive + float totalPrimitiveWeight = 0.0; + uint nPrimitives = node.getNumPrimitives(); + for (uint p = 0; p < nPrimitives; p++) + { + Interaction c; + bool didIntersectPrimitive = false; + uint primitiveIndex = node.getPrimitiveOffset() + p; + P primitive = primitives[primitiveIndex]; + + if (maxDistToChildNode <= s.r2) + { + didIntersectPrimitive = true; + c.d = primitive.getSurfaceArea(); + c.index = primitive.getIndex(); + } + else + { + didIntersectPrimitive = primitive.intersect(s, c); + } + + if (didIntersectPrimitive) + { + didIntersect = true; + totalPrimitiveWeight += c.d; + float selectionProb = c.d / totalPrimitiveWeight; + + if (u < selectionProb) + { + u = u / selectionProb; // rescale to [0,1) + i = c; + i.d *= traversalPdf; + selectedPrimitiveIndex = primitiveIndex; + } + else + { + u = (u - selectionProb) / (1.0 - selectionProb); + } + } + } + + if (totalPrimitiveWeight > 0.0) + { + i.d /= totalPrimitiveWeight; + } + } + else + { + // probabilistically select one child node to traverse + uint leftNodeIndex = currentNodeIndex + 1; + BoundingBox leftBox = nodes[leftNodeIndex].getBoundingBox(); + bool overlapsLeft = leftBox.overlap(s, distToChildNodes[0], distToChildNodes[1]); + float weightLeft = overlapsLeft ? 1.0 : 0.0; + if (weightLeft > 0.0) + { + float3 u = s.c - leftBox.getCentroid(); + weightLeft *= branchTraversalWeight.compute(dot(u, u)); + } + + uint rightNodeIndex = currentNodeIndex + node.getRightChildOffset(); + BoundingBox rightBox = nodes[rightNodeIndex].getBoundingBox(); + bool overlapsRight = rightBox.overlap(s, distToChildNodes[2], distToChildNodes[3]); + float weightRight = overlapsRight ? 1.0 : 0.0; + if (weightRight > 0.0) + { + float3 u = s.c - rightBox.getCentroid(); + weightRight *= branchTraversalWeight.compute(dot(u, u)); + } + + float totalTraversalWeight = weightLeft + weightRight; + if (totalTraversalWeight > 0.0) + { + stackPtr++; + float traversalProbLeft = weightLeft / totalTraversalWeight; + float traversalProbRight = 1.0 - traversalProbLeft; + + if (u < traversalProbLeft) + { + u = u / traversalProbLeft; // rescale to [0,1) + currentNodeIndex = leftNodeIndex; + traversalPdf *= traversalProbLeft; + maxDistToChildNode = distToChildNodes[1]; + } + else + { + u = (u - traversalProbLeft) / traversalProbRight; // rescale to [0,1) + currentNodeIndex = rightNodeIndex; + traversalPdf *= traversalProbRight; + maxDistToChildNode = distToChildNodes[3]; + } + } + } + } + } + + if (didIntersect) + { + if (i.index == UINT_MAX || selectedPrimitiveIndex == UINT_MAX) + { + didIntersect = false; + } + else + { + // sample a point on the selected geometric primitive + float samplingPdf = primitives[selectedPrimitiveIndex].samplePoint(randNums.yz, i.uv, i.p, i.n); + i.d *= samplingPdf; + } + } + + return didIntersect; + } + + // finds closest point on aggregate geometry from sphere center + public bool findClosestPoint(inout BoundingSphere s, inout Interaction i, + bool recordNormal = false) + { + TraversalStack traversalStack[FCPW_BVH_MAX_DEPTH]; + float4 distToChildNodes = float4(0.0, 0.0, 0.0, 0.0); + BoundingBox rootBox = nodes[0].getBoundingBox(); + bool notFound = true; + + if (rootBox.overlap(s, distToChildNodes[0], distToChildNodes[1])) + { + s.r2 = min(s.r2, distToChildNodes[1]); + traversalStack[0].node = 0; + traversalStack[0].distance = distToChildNodes[0]; + int stackPtr = 0; + + while (stackPtr >= 0) + { + // pop off the next node to work on + uint currentNodeIndex = traversalStack[stackPtr].node; + float currentDist = traversalStack[stackPtr].distance; + stackPtr--; + + // if this node is further than the closest found primitive, continue + if (currentDist > s.r2) + { + continue; + } + + N node = nodes[currentNodeIndex]; + if (node.isLeaf()) + { + // compute distance to primitives in leaf node + uint nPrimitives = node.getNumPrimitives(); + for (uint p = 0; p < nPrimitives; p++) + { + Interaction c; + uint primitiveIndex = node.getPrimitiveOffset() + p; + bool found = primitives[primitiveIndex].findClosestPoint(s, c); + + // keep the closest point only + if (found) + { + notFound = false; + s.r2 = min(s.r2, c.d * c.d); + i = c; + } + } + } + else + { + // find distance to child nodes + uint leftNodeIndex = currentNodeIndex + 1; + BoundingBox leftBox = nodes[leftNodeIndex].getBoundingBox(); + bool overlapsLeft = leftBox.overlap(s, distToChildNodes[0], distToChildNodes[1]); + s.r2 = min(s.r2, distToChildNodes[1]); + + uint rightNodeIndex = currentNodeIndex + node.getRightChildOffset(); + BoundingBox rightBox = nodes[rightNodeIndex].getBoundingBox(); + bool overlapsRight = rightBox.overlap(s, distToChildNodes[2], distToChildNodes[3]); + s.r2 = min(s.r2, distToChildNodes[3]); + + // which nodes do we overlap? + if (overlapsLeft && overlapsRight) + { + // assume that the left child is closer + uint closer = leftNodeIndex; + uint other = rightNodeIndex; + + // ... if the right child was actually closer, swap the relavent values + if (distToChildNodes[0] == 0.0 && distToChildNodes[2] == 0.0) + { + if (distToChildNodes[3] < distToChildNodes[1]) + { + uint tmpNodeIndex = closer; + closer = other; + other = tmpNodeIndex; + } + } + else if (distToChildNodes[2] < distToChildNodes[0]) + { + float tmpDist = distToChildNodes[0]; + distToChildNodes[0] = distToChildNodes[2]; + distToChildNodes[2] = tmpDist; + + uint tmpNodeIndex = closer; + closer = other; + other = tmpNodeIndex; + } + + // it's possible that the nearest primitive is still in the other node, + // but we'll check the farther-away node later. + + // push the further node first, then the closer node + stackPtr++; + traversalStack[stackPtr].node = other; + traversalStack[stackPtr].distance = distToChildNodes[2]; + + stackPtr++; + traversalStack[stackPtr].node = closer; + traversalStack[stackPtr].distance = distToChildNodes[0]; + } + else if (overlapsLeft) + { + stackPtr++; + traversalStack[stackPtr].node = leftNodeIndex; + traversalStack[stackPtr].distance = distToChildNodes[0]; + } + else if (overlapsRight) + { + stackPtr++; + traversalStack[stackPtr].node = rightNodeIndex; + traversalStack[stackPtr].distance = distToChildNodes[2]; + } + } + } + } + + if (!notFound && recordNormal) + { + i.n = primitives[i.index].getNormal(); + } + + return !notFound; + } + + // finds closest silhouette point on aggregate geometry from sphere center + public bool findClosestSilhouettePoint(inout BoundingSphere s, bool flipNormalOrientation, + float squaredMinRadius, float precision, + inout Interaction i) + { + if (squaredMinRadius >= s.r2) + { + return false; + } + + TraversalStack traversalStack[FCPW_BVH_MAX_DEPTH]; + float2 distToChildNodes = float2(0.0, 0.0); + BoundingBox rootBox = nodes[0].getBoundingBox(); + bool notFound = true; + + if (rootBox.overlap(s, distToChildNodes[0])) + { + traversalStack[0].node = 0; + traversalStack[0].distance = distToChildNodes[0]; + int stackPtr = 0; + + while (stackPtr >= 0) + { + // pop off the next node to work on + uint currentNodeIndex = traversalStack[stackPtr].node; + float currentDist = traversalStack[stackPtr].distance; + stackPtr--; + + // if this node is further than the closest found primitive, continue + if (currentDist > s.r2) + { + continue; + } + + N node = nodes[currentNodeIndex]; + if (node.isLeaf()) + { + // compute distance to silhouettes in leaf node + uint nSilhouettes = node.getNumSilhouettes(); + for (uint p = 0; p < nSilhouettes; p++) + { + uint silhouetteIndex = node.getSilhouetteOffset() + p; + S silhouette = silhouettes[silhouetteIndex]; + if (silhouette.getIndex() == i.index) + { + // silhouette has already been checked + continue; + } + + Interaction c; + bool found = silhouette.findClosestSilhouettePoint( + s, flipNormalOrientation, squaredMinRadius, precision, c); + + // keep the closest silhouette point + if (found) + { + notFound = false; + s.r2 = min(s.r2, c.d * c.d); + i = c; + + if (squaredMinRadius >= s.r2) + { + break; + } + } + } + } + else + { + // find distance to child nodes + // NOTE: Slang does not support short-circuiting with the && operator, hence the clunky code + uint leftNodeIndex = currentNodeIndex + 1; + N leftNode = nodes[leftNodeIndex]; + BoundingCone leftCone = leftNode.getBoundingCone(); + bool overlapsLeft = leftCone.isValid(); + if (overlapsLeft) + { + BoundingBox leftBox = leftNode.getBoundingBox(); + overlapsLeft = leftBox.overlap(s, distToChildNodes[0]); + if (overlapsLeft) + { + float minAngleRange, maxAngleRange; + overlapsLeft = leftCone.overlap(s.c, leftBox, distToChildNodes[0], + minAngleRange, maxAngleRange); + } + } + + uint rightNodeIndex = currentNodeIndex + node.getRightChildOffset(); + N rightNode = nodes[rightNodeIndex]; + BoundingCone rightCone = rightNode.getBoundingCone(); + bool overlapsRight = rightCone.isValid(); + if (overlapsRight) + { + BoundingBox rightBox = rightNode.getBoundingBox(); + overlapsRight = rightBox.overlap(s, distToChildNodes[1]); + if (overlapsRight) + { + float minAngleRange, maxAngleRange; + overlapsRight = rightCone.overlap(s.c, rightBox, distToChildNodes[1], + minAngleRange, maxAngleRange); + } + } + + // which nodes do we overlap? + if (overlapsLeft && overlapsRight) + { + // assume that the left child is closer + uint closer = leftNodeIndex; + uint other = rightNodeIndex; + + // ... if the right child was actually closer, swap the relavent values + if (distToChildNodes[1] < distToChildNodes[0]) + { + float tmpDist = distToChildNodes[0]; + distToChildNodes[0] = distToChildNodes[1]; + distToChildNodes[1] = tmpDist; + + uint tmpNodeIndex = closer; + closer = other; + other = tmpNodeIndex; + } + + // it's possible that the nearest primitive is still in the other node, + // but we'll check the farther-away node later. + + // push the further node first, then the closer node + stackPtr++; + traversalStack[stackPtr].node = other; + traversalStack[stackPtr].distance = distToChildNodes[1]; + + stackPtr++; + traversalStack[stackPtr].node = closer; + traversalStack[stackPtr].distance = distToChildNodes[0]; + } + else if (overlapsLeft) + { + stackPtr++; + traversalStack[stackPtr].node = leftNodeIndex; + traversalStack[stackPtr].distance = distToChildNodes[0]; + } + else if (overlapsRight) + { + stackPtr++; + traversalStack[stackPtr].node = rightNodeIndex; + traversalStack[stackPtr].distance = distToChildNodes[1]; + } + } + } + } + + return !notFound; + } +}; diff --git a/tests/fcpw/fcpw.slang b/tests/fcpw/fcpw.slang new file mode 100644 index 000000000..b0c4b33ee --- /dev/null +++ b/tests/fcpw/fcpw.slang @@ -0,0 +1,2 @@ +module fcpw; +__include bvh;
\ No newline at end of file diff --git a/tests/fcpw/geometry.slang b/tests/fcpw/geometry.slang new file mode 100644 index 000000000..9f62b22c3 --- /dev/null +++ b/tests/fcpw/geometry.slang @@ -0,0 +1,708 @@ +implementing fcpw; +__include ray; +__include math_constants; +__include interaction; +__include bounding_volumes; + +public interface IPrimitive +{ + // returns the bounding box of the primitive + BoundingBox getBoundingBox(); + + // returns the centroid of the primitive + float3 getCentroid(); + + // returns the normal of the primitive + float3 getNormal(); + + // returns the surface area of the primitive + float getSurfaceArea(); + + // intersects primitive with ray + bool intersect(Ray r, bool checkForOcclusion, inout Interaction i); + + // intersects primitive with sphere + bool intersect(BoundingSphere s, inout Interaction i); + + // finds closest point on primitive from sphere center + bool findClosestPoint(BoundingSphere s, inout Interaction i); + + // samples point on primitive and returns sampling pdf + float samplePoint(float2 randNums, out float2 uv, out float3 p, out float3 n); + + // returns the index of the primitive + uint getIndex(); +}; + +public bool intersectLineSegment(float3 pa, float3 pb, + float3 ro, float3 rd, float rtMax, bool checkForOcclusion, + inout float3 p, inout float3 n, inout float2 uv, inout float d) +{ + float3 u = pa - ro; + float3 v = pb - pa; + + // return if line segment and ray are parallel + float dv = cross(rd, v)[2]; + if (abs(dv) <= FLT_EPSILON) + { + return false; + } + + // solve ro + t*rd = pa + s*(pb - pa) for t >= 0 && 0 <= s <= 1 + // s = (u x rd)/(rd x v) + float ud = cross(u, rd)[2]; + float s = ud / dv; + + if (s >= 0.0 && s <= 1.0) + { + // t = (u x v)/(rd x v) + float t = cross(u, v)[2] / dv; + + if (t >= 0.0 && t <= rtMax) + { + if (checkForOcclusion) + { + return true; + } + + p = pa + s * v; + n = normalize(float3(v[1], -v[0], 0.0)); + uv = float2(s, 0.0); + d = t; + return true; + } + } + + return false; +} + +public float findClosestPointLineSegment(float3 pa, float3 pb, float3 x, out float3 p, out float t) +{ + float3 u = pb - pa; + float3 v = x - pa; + + float c1 = dot(u, v); + if (c1 <= 0.0) + { + t = 0.0; + p = pa; + return length(x - p); + } + + float c2 = dot(u, u); + if (c2 <= c1) + { + t = 1.0; + p = pb; + return length(x - p); + } + + t = c1 / c2; + p = pa + u * t; + return length(x - p); +} + +public struct LineSegment : IPrimitive +{ + public float3 pa; + public float3 pb; + public uint index; + + // returns the bounding box of the primitive + public BoundingBox getBoundingBox() + { + float3 epsilon = float3(FLT_EPSILON, FLT_EPSILON, 0.0); + return BoundingBox(min(pa, pb) - epsilon, max(pa, pb) + epsilon); + } + + // returns the centroid of the primitive + public float3 getCentroid() + { + return 0.5 * (pa + pb); + } + + // returns the normal of the primitive + public float3 getNormal() + { + float3 s = pb - pa; + float3 n = float3(s.y, -s.x, 0.0); + + return normalize(n); + } + + // returns the surface area of the primitive + public float getSurfaceArea() + { + return length(pb - pa); + } + + // intersects primitive with ray + // NOTE: specialized to 2D (z coordinate == 0) + public bool intersect(Ray r, bool checkForOcclusion, inout Interaction i) + { + bool didIntersect = intersectLineSegment(pa, pb, r.o, r.d, r.tMax, checkForOcclusion, i.p, i.n, i.uv, i.d); + if (didIntersect) + { + i.index = index; + return true; + } + + return false; + } + + // intersects primitive with sphere + public bool intersect(BoundingSphere s, inout Interaction i) + { + float d = findClosestPointLineSegment(pa, pb, s.c, i.p, i.uv[0]); + if (d * d <= s.r2) + { + i.d = getSurfaceArea(); + i.index = index; + return true; + } + + return false; + } + + // finds closest point on primitive from sphere center + public bool findClosestPoint(BoundingSphere s, inout Interaction i) + { + float d = findClosestPointLineSegment(pa, pb, s.c, i.p, i.uv[0]); + if (d * d <= s.r2) + { + i.uv[1] = 0.0; + i.d = d; + i.index = index; + return true; + } + + return false; + } + + // samples point on primitive and returns sampling pdf + public float samplePoint(float2 randNums, out float2 uv, out float3 p, out float3 n) + { + float3 s = pb - pa; + float area = length(s); + float u = randNums[0]; + uv = float2(u, 0.0); + p = pa + u * s; + n = float3(s[1], -s[0], 0.0) / area; + + return 1.0 / area; + } + + // returns the index of the primitive + public uint getIndex() + { + return index; + } +}; + +public bool intersectTriangle(float3 pa, float3 pb, float3 pc, + float3 ro, float3 rd, float rtMax, bool checkForOcclusion, + inout float3 p, inout float3 n, inout float2 uv, inout float d) +{ + // Möller–Trumbore intersection algorithm + float3 v1 = pb - pa; + float3 v2 = pc - pa; + float3 q = cross(rd, v2); + float det = dot(v1, q); + + // ray and triangle are parallel if det is close to 0 + if (abs(det) <= FLT_EPSILON) + { + return false; + } + float invDet = 1.0 / det; + + float3 r = ro - pa; + float v = dot(r, q) * invDet; + if (v < 0.0 || v > 1.0) + { + return false; + } + + float3 s = cross(r, v1); + float w = dot(rd, s) * invDet; + if (w < 0.0 || v + w > 1.0) + { + return false; + } + + float t = dot(v2, s) * invDet; + if (t >= 0.0 && t <= rtMax) + { + if (checkForOcclusion) + { + return true; + } + + p = pa + v1 * v + v2 * w; + n = normalize(cross(v1, v2)); + uv = float2(1.0 - v - w, v); + d = t; + return true; + } + + return false; +} + +public float findClosestPointTriangle(float3 pa, float3 pb, float3 pc, float3 x, out float3 p, out float2 t) +{ + // source: real time collision detection + // check if x in vertex region outside pa + float3 ab = pb - pa; + float3 ac = pc - pa; + float3 ax = x - pa; + float d1 = dot(ab, ax); + float d2 = dot(ac, ax); + if (d1 <= 0.0 && d2 <= 0.0) + { + // barycentric coordinates (1, 0, 0) + t = float2(1.0, 0.0); + p = pa; + return length(x - p); + } + + // check if x in vertex region outside pb + float3 bx = x - pb; + float d3 = dot(ab, bx); + float d4 = dot(ac, bx); + if (d3 >= 0.0 && d4 <= d3) + { + // barycentric coordinates (0, 1, 0) + t = float2(0.0, 1.0); + p = pb; + return length(x - p); + } + + // check if x in vertex region outside pc + float3 cx = x - pc; + float d5 = dot(ab, cx); + float d6 = dot(ac, cx); + if (d6 >= 0.0 && d5 <= d6) + { + // barycentric coordinates (0, 0, 1) + t = float2(0.0, 0.0); + p = pc; + return length(x - p); + } + + // check if x in edge region of ab, if so return projection of x onto ab + float vc = d1 * d4 - d3 * d2; + if (vc <= 0.0 && d1 >= 0.0 && d3 <= 0.0) + { + // barycentric coordinates (1 - v, v, 0) + float v = d1 / (d1 - d3); + t = float2(1.0 - v, v); + p = pa + ab * v; + return length(x - p); + } + + // check if x in edge region of ac, if so return projection of x onto ac + float vb = d5 * d2 - d1 * d6; + if (vb <= 0.0 && d2 >= 0.0 && d6 <= 0.0) + { + // barycentric coordinates (1 - w, 0, w) + float w = d2 / (d2 - d6); + t = float2(1.0 - w, 0.0); + p = pa + ac * w; + return length(x - p); + } + + // check if x in edge region of bc, if so return projection of x onto bc + float va = d3 * d6 - d5 * d4; + if (va <= 0.0 && (d4 - d3) >= 0.0 && (d5 - d6) >= 0.0) + { + // barycentric coordinates (0, 1 - w, w) + float w = (d4 - d3) / ((d4 - d3) + (d5 - d6)); + t = float2(0.0, 1.0 - w); + p = pb + (pc - pb) * w; + return length(x - p); + } + + // x inside face region. Compute p through its barycentric coordinates (u, v, w) + float denom = 1.0 / (va + vb + vc); + float v = vb * denom; + float w = vc * denom; + t = float2(1.0 - v - w, v); + p = pa + ab * v + ac * w; //= u*a + v*b + w*c, u = va*denom = 1.0f - v - w + return length(x - p); +} + +public struct Triangle : IPrimitive +{ + public float3 pa; + public float3 pb; + public float3 pc; + public uint index; + + // returns the bounding box of the primitive + public BoundingBox getBoundingBox() + { + float3 epsilon = float3(FLT_EPSILON, FLT_EPSILON, FLT_EPSILON); + return BoundingBox(min(min(pa, pb), pc) - epsilon, max(max(pa, pb), pc) + epsilon); + } + + // returns the centroid of the primitive + public float3 getCentroid() + { + return (pa + pb + pc) / 3.0; + } + + // returns the surface area of the primitive + public float getSurfaceArea() + { + return 0.5 * length(cross(pb - pa, pc - pa)); + } + + // returns the normal of the primitive + public float3 getNormal() + { + float3 n = cross(pb - pa, pc - pa); + + return normalize(n); + } + + // intersects primitive with ray + public bool intersect(Ray r, bool checkForOcclusion, inout Interaction i) + { + bool didIntersect = intersectTriangle(pa, pb, pc, r.o, r.d, r.tMax, checkForOcclusion, i.p, i.n, i.uv, i.d); + if (didIntersect) + { + i.index = index; + return true; + } + + return false; + } + + // intersects primitive with sphere + public bool intersect(BoundingSphere s, inout Interaction i) + { + float d = findClosestPointTriangle(pa, pb, pc, s.c, i.p, i.uv); + if (d * d <= s.r2) + { + i.d = getSurfaceArea(); + i.index = index; + return true; + } + + return false; + } + + // finds closest point on primitive from sphere center + public bool findClosestPoint(BoundingSphere s, inout Interaction i) + { + float d = findClosestPointTriangle(pa, pb, pc, s.c, i.p, i.uv); + if (d * d <= s.r2) + { + i.d = d; + i.index = index; + return true; + } + + return false; + } + + // samples point on primitive and returns sampling pdf + public float samplePoint(float2 randNums, out float2 uv, out float3 p, out float3 n) + { + n = cross(pb - pa, pc - pa); + float area = length(n); + float u1 = sqrt(randNums[0]); + float u2 = randNums[1]; + float u = 1.0 - u1; + float v = u2 * u1; + float w = 1.0 - u - v; + uv = float2(u, v); + p = pa * u + pb * v + pc * w; + n /= area; + + return 2.0 / area; + } + + // returns the index of the primitive + public uint getIndex() + { + return index; + } +}; + +public interface ISilhouette +{ + // returns the centroid of the silhouette + float3 getCentroid(); + + // returns whether silhouette has two adjacent faces + bool hasTwoAdjacentFaces(); + + // returns normal of adjacent face + float3 getNormal(uint fIndex); + + // finds closest silhouette point on primitive from sphere center + bool findClosestSilhouettePoint(BoundingSphere s, bool flipNormalOrientation, + float squaredMinRadius, float precision, + inout Interaction i); + + // returns the index of the silhouette + uint getIndex(); +}; + +public struct NoSilhouette : ISilhouette +{ + public uint index; + + // returns the centroid of the silhouette + public float3 getCentroid() + { + return float3(0.0, 0.0, 0.0); + } + + // returns whether silhouette has two adjacent faces + public bool hasTwoAdjacentFaces() + { + return false; + } + + // returns normal of adjacent face + public float3 getNormal(uint fIndex) + { + return float3(0.0, 0.0, 0.0); + } + + // finds closest silhouette point on primitive from sphere center + public bool findClosestSilhouettePoint(BoundingSphere s, bool flipNormalOrientation, + float squaredMinRadius, float precision, + inout Interaction i) + { + return false; + } + + // returns the index of the silhouette + public uint getIndex() + { + return UINT_MAX; + } +}; + +public bool isSilhouetteVertex(float3 n0, float3 n1, float3 viewDir, float d, bool flipNormalOrientation, float precision) +{ + float sign = flipNormalOrientation ? 1.0 : -1.0; + + // vertex is a silhouette point if it is concave and the query point lies on the vertex + if (d <= precision) + { + float det = n0.x * n1.y - n1.x * n0.y; + return sign * det > precision; + } + + // vertex is a silhouette point if the query point lies on the halfplane + // defined by an adjacent line segment and the other segment is backfacing + float3 viewDirUnit = viewDir / d; + float dot0 = dot(viewDirUnit, n0); + float dot1 = dot(viewDirUnit, n1); + + bool isZeroDot0 = abs(dot0) <= precision; + if (isZeroDot0) + { + return sign * dot1 > precision; + } + + bool isZeroDot1 = abs(dot1) <= precision; + if (isZeroDot1) + { + return sign * dot0 > precision; + } + + // vertex is a silhouette point if an adjacent line segment is frontfacing + // w.r.t. the query point and the other segment is backfacing + return dot0 * dot1 < 0.0; +} + +public struct Vertex : ISilhouette +{ + public float3 p; + public float3 n0; + public float3 n1; + public uint index; + public uint hasOneAdjacentFace; + + // returns the centroid of the silhouette + public float3 getCentroid() + { + return p; + } + + // returns whether silhouette has two adjacent faces + public bool hasTwoAdjacentFaces() + { + return hasOneAdjacentFace == 0; + } + + // returns normal of adjacent face + public float3 getNormal(uint fIndex) + { + if (fIndex == 0) + { + return n0; + } + + return n1; + } + + // finds closest silhouette point on primitive from sphere center + public bool findClosestSilhouettePoint(BoundingSphere s, bool flipNormalOrientation, + float squaredMinRadius, float precision, + inout Interaction i) + { + if (squaredMinRadius >= s.r2) + { + return false; + } + + // compute view direction + float3 viewDir = s.c - p; + float d = length(viewDir); + if (d * d > s.r2) + { + return false; + } + + // check if vertex is a silhouette point from view direction + bool process = hasOneAdjacentFace == 1 ? true : false; + if (!process) + { + process = isSilhouetteVertex(n0, n1, viewDir, d, flipNormalOrientation, precision); + } + + if (process && d * d <= s.r2) + { + i.p = p; + i.uv = float2(0.0, 0.0); + i.d = d; + i.index = index; + return true; + } + + return false; + } + + // returns the index of the silhouette + public uint getIndex() + { + return index; + } +}; + +public bool isSilhouetteEdge(float3 pa, float3 pb, float3 n0, float3 n1, float3 viewDir, + float d, bool flipNormalOrientation, float precision) +{ + float sign = flipNormalOrientation ? 1.0 : -1.0; + + // edge is a silhouette if it is concave and the query point lies on the edge + if (d <= precision) + { + float3 edgeDir = normalize(pb - pa); + float signedDihedralAngle = atan2(dot(edgeDir, cross(n0, n1)), dot(n0, n1)); + return sign * signedDihedralAngle > precision; + } + + // edge is a silhouette if the query point lies on the halfplane defined + // by an adjacent triangle and the other triangle is backfacing + float3 viewDirUnit = viewDir / d; + float dot0 = dot(viewDirUnit, n0); + float dot1 = dot(viewDirUnit, n1); + + bool isZeroDot0 = abs(dot0) <= precision; + if (isZeroDot0) + { + return sign * dot1 > precision; + } + + bool isZeroDot1 = abs(dot1) <= precision; + if (isZeroDot1) + { + return sign * dot0 > precision; + } + + // edge is a silhouette if an adjacent triangle is frontfacing w.r.t. the + // query point and the other triangle is backfacing + return dot0 * dot1 < 0.0; +} + +public struct Edge : ISilhouette +{ + public float3 pa; + public float3 pb; + public float3 n0; + public float3 n1; + public uint index; + public uint hasOneAdjacentFace; + + // returns the centroid of the silhouette + public float3 getCentroid() + { + return 0.5 * (pa + pb); + } + + // returns whether silhouette has two adjacent faces + public bool hasTwoAdjacentFaces() + { + return hasOneAdjacentFace == 0; + } + + // returns normal of adjacent face + public float3 getNormal(uint fIndex) + { + if (fIndex == 0) + { + return n0; + } + + return n1; + } + + // finds closest silhouette point on primitive from sphere center + public bool findClosestSilhouettePoint(BoundingSphere s, bool flipNormalOrientation, + float squaredMinRadius, float precision, + inout Interaction i) + { + if (squaredMinRadius >= s.r2) + { + return false; + } + + // compute view direction + float d = findClosestPointLineSegment(pa, pb, s.c, i.p, i.uv[0]); + if (d * d > s.r2) + { + return false; + } + + // check if edge is a silhouette from view direction + bool process = hasOneAdjacentFace == 1 ? true : false; + if (!process) + { + float3 viewDir = s.c - i.p; + process = isSilhouetteEdge(pa, pb, n0, n1, viewDir, d, flipNormalOrientation, precision); + } + + if (process && d * d <= s.r2) + { + i.uv[1] = 0.0; + i.d = d; + i.index = index; + return true; + } + + return false; + } + + // returns the index of the silhouette + public uint getIndex() + { + return index; + } +}; diff --git a/tests/fcpw/interaction.slang b/tests/fcpw/interaction.slang new file mode 100644 index 000000000..3640016eb --- /dev/null +++ b/tests/fcpw/interaction.slang @@ -0,0 +1,21 @@ +implementing fcpw; +__include math_constants; + +public struct Interaction +{ + public float3 p; // interaction point associated with query + public float3 n; // normal at interaction point + public float2 uv; // uv coordinates of interaction point + public float d; // distance to interaction point + public uint index; // index of primitive/silhouette associated with interaction point + + // constructor + public __init() + { + p = float3(0.0, 0.0, 0.0); + n = float3(0.0, 0.0, 0.0); + uv = float2(0.0, 0.0); + d = FLT_MAX; + index = UINT_MAX; + } +}; diff --git a/tests/fcpw/math-constants.slang b/tests/fcpw/math-constants.slang new file mode 100644 index 000000000..e9bf4b3db --- /dev/null +++ b/tests/fcpw/math-constants.slang @@ -0,0 +1,7 @@ +implementing fcpw; + +internal static const float M_PI = 3.14159265358979323846; // pi +internal static const float M_PI_2 = 1.57079632679489661923; // pi/2 +internal static const float FLT_MAX = 3.402823466e+38F; // max float value +internal static const float FLT_EPSILON = 1.192092896e-07F; // smallest float value such that 1.0+FLT_EPSILON != 1.0 +internal static const uint UINT_MAX = 4294967295; // max unsigned int value
\ No newline at end of file diff --git a/tests/fcpw/ray.slang b/tests/fcpw/ray.slang new file mode 100644 index 000000000..f12664ed7 --- /dev/null +++ b/tests/fcpw/ray.slang @@ -0,0 +1,18 @@ +implementing fcpw; + +public struct Ray +{ + public float3 o; // ray origin + public float3 d; // ray direction + public float3 dInv; // 1 over ray direction (coordinate-wise) + public float tMax; // max ray distance + + // constructor + public __init(float3 o_, float3 d_, float tMax_) + { + o = o_; + d = d_; + dInv = float3(1.0, 1.0, 1.0) / d_; + tMax = tMax_; + } +}; diff --git a/tests/fcpw/transform.slang b/tests/fcpw/transform.slang new file mode 100644 index 000000000..7e69f7227 --- /dev/null +++ b/tests/fcpw/transform.slang @@ -0,0 +1,56 @@ +implementing fcpw; +__include ray; +__include math_constants; +__include interaction; +__include bounding_volumes; + +public float3x3 extractLinearTransform(float3x4 t) +{ + return float3x3(t[0].xyz, t[1].xyz, t[2].xyz); +} + +public float3 extractTranslation(float3x4 t) +{ + return float3(t[0][3], t[1][3], t[2][3]); +} + +public float3 transformPoint(float3x4 t, float3 p) +{ + return mul(t, float4(p, 1.0)).xyz; +} + +public Ray transformRay(float3x4 t, Ray r) +{ + float3 o = transformPoint(t, r.o); + float3 d = transformPoint(t, r.o + r.d * (r.tMax < FLT_MAX ? r.tMax : 1.0)) - o; + float dNorm = length(d); + + return Ray(o, d / dNorm, r.tMax < FLT_MAX ? dNorm : FLT_MAX); +} + +public BoundingSphere transformSphere(float3x4 t, BoundingSphere s) +{ + float3 c = transformPoint(t, s.c); + float r2 = FLT_MAX; + if (s.r2 < FLT_MAX) + { + float3 d = transformPoint(t, s.c + float3(sqrt(s.r2), 0.0, 0.0)) - c; + r2 = dot(d, d); + } + + return BoundingSphere(c, r2); +} + +public void transformInteraction(float3x4 t, float3x4 tInv, float3 x, + bool overwriteDistance, inout Interaction i) +{ + float3 p = transformPoint(t, i.p); + float3 n = normalize(mul(transpose(extractLinearTransform(tInv)), i.n)); + + i.p = p; + i.n = n; + if (overwriteDistance) + { + i.d = length(p - x); + } +} |
