diff options
| author | Yong He <yonghe@outlook.com> | 2023-09-03 12:56:31 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-09-03 12:56:31 -0700 |
| commit | 1d4b5b6fd2433a10cc7ab87626cb560f54b0acbb (patch) | |
| tree | 6196d519190720fd2968ac7d4b373e3c967d5fe6 | |
| parent | 355bb4287861f96082751042f4e58ff3598b4e5e (diff) | |
Proper lowering of functiosn that returns NonCopyable values. (#3179)
* Proper lowering of functiosn that returns NonCopyable values.
* Fix tests.
* Fix clang errors.
* Fix.
* Fix clang error.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
28 files changed, 884 insertions, 697 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index d56c4ffcd..23815d2e9 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -8286,8 +8286,6 @@ struct HitObject RayDesc Ray, inout payload_t Payload) { - HitObject hitObj; - [__vulkanRayPayload] static payload_t p; @@ -8295,7 +8293,7 @@ struct HitObject p = Payload; __glslTraceRay( - hitObj, + __return_val, AccelerationStructure, RayFlags, // Assumes D3D/VK have some RayFlags values InstanceInclusionMask, // cullMask @@ -8310,8 +8308,6 @@ struct HitObject // Write the payload out Payload = p; - - return hitObj; } /// Executes motion ray traversal (including anyhit and intersection shaders) like TraceRay, but returns the @@ -8329,8 +8325,6 @@ struct HitObject float CurrentTime, inout payload_t Payload) { - HitObject hitObj; - [__vulkanRayPayload] static payload_t p; @@ -8338,7 +8332,7 @@ struct HitObject p = Payload; __glslTraceMotionRay( - hitObj, + __return_val, AccelerationStructure, RayFlags, // Assumes D3D/VK have some RayFlags values InstanceInclusionMask, // cullMask @@ -8354,8 +8348,6 @@ struct HitObject // Write the payload out Payload = p; - - return hitObj; } /// Creates a HitObject representing a hit based on values explicitly passed as arguments, without @@ -8404,14 +8396,13 @@ struct HitObject RayDesc Ray, attr_t attributes) { - HitObject hitObj; - // Save the attributes __ref attr_t attr = __hitObjectAttributes<attr_t>(); attr = attributes; - __glslMakeHit(hitObj, + __glslMakeHit( + __return_val, AccelerationStructure, InstanceIndex, PrimitiveIndex, @@ -8424,8 +8415,6 @@ struct HitObject Ray.Direction, Ray.TMax, __hitObjectAttributesLocation(__hitObjectAttributes<attr_t>())); - - return hitObj; } /// See MakeHit but handles Motion @@ -8444,14 +8433,13 @@ struct HitObject float CurrentTime, attr_t attributes) { - HitObject hitObj; - // Save the attributes __ref attr_t attr = __hitObjectAttributes<attr_t>(); attr = attributes; - __glslMakeMotionHit(hitObj, + __glslMakeMotionHit( + __return_val, AccelerationStructure, InstanceIndex, PrimitiveIndex, @@ -8465,8 +8453,6 @@ struct HitObject Ray.TMax, CurrentTime, __hitObjectAttributesLocation(__hitObjectAttributes<attr_t>())); - - return hitObj; } /// Creates a HitObject representing a hit based on values explicitly passed as arguments, without @@ -8513,13 +8499,12 @@ struct HitObject RayDesc Ray, attr_t attributes) { - HitObject hitObj; - // Save the attributes __ref attr_t attr = __hitObjectAttributes<attr_t>(); attr = attributes; - __glslMakeHitWithIndex(hitObj, + __glslMakeHitWithIndex( + __return_val, AccelerationStructure, InstanceIndex, ///? Same as instanceid ? GeometryIndex, @@ -8531,8 +8516,6 @@ struct HitObject Ray.Direction, Ray.TMax, __hitObjectAttributesLocation(__hitObjectAttributes<attr_t>())); - - return hitObj; } /// See MakeHit but handles Motion @@ -8556,7 +8539,8 @@ struct HitObject __ref attr_t attr = __hitObjectAttributes<attr_t>(); attr = attributes; - __glslMakeMotionHitWithIndex(hitObj, + __glslMakeMotionHitWithIndex( + __return_val, AccelerationStructure, InstanceIndex, ///? Same as instanceid ? GeometryIndex, @@ -8569,8 +8553,6 @@ struct HitObject Ray.TMax, CurrentTime, __hitObjectAttributesLocation(__hitObjectAttributes<attr_t>())); - - return hitObj; } /// Creates a HitObject representing a miss based on values explicitly passed as arguments, without @@ -8588,9 +8570,7 @@ struct HitObject uint MissShaderIndex, RayDesc Ray) { - HitObject hitObj; - __glslMakeMiss(hitObj, MissShaderIndex, Ray.Origin, Ray.TMin, Ray.Direction, Ray.TMax); - return hitObj; + __glslMakeMiss(__return_val, MissShaderIndex, Ray.Origin, Ray.TMin, Ray.Direction, Ray.TMax); } /// See MakeMiss but handles Motion @@ -8602,9 +8582,7 @@ struct HitObject RayDesc Ray, float CurrentTime) { - HitObject hitObj; - __glslMakeMotionMiss(hitObj, MissShaderIndex, Ray.Origin, Ray.TMin, Ray.Direction, Ray.TMax, CurrentTime); - return hitObj; + __glslMakeMotionMiss(__return_val, MissShaderIndex, Ray.Origin, Ray.TMin, Ray.Direction, Ray.TMax, CurrentTime); } /// Creates a HitObject representing “NOP” (no operation) which is neither a hit nor a miss. Invoking a @@ -8620,9 +8598,7 @@ struct HitObject __specialized_for_target(glsl) static HitObject MakeNop() { - HitObject hitObj; - __glslMakeNop(hitObj); - return hitObj; + __glslMakeNop(__return_val); } /// Invokes closesthit or miss shading for the specified hit object. In case of a NOP HitObject, no @@ -8656,18 +8632,21 @@ struct HitObject [__requiresNVAPI] __target_intrinsic(hlsl) __target_intrinsic(glsl, "hitObjectIsMissNV($0)") + __glsl_extension(GL_EXT_ray_tracing) bool IsMiss(); /// Returns true if the HitObject encodes a hit, otherwise returns false. [__requiresNVAPI] __target_intrinsic(hlsl) __target_intrinsic(glsl, "hitObjectIsHitNV($0)") + __glsl_extension(GL_EXT_ray_tracing) bool IsHit(); /// Returns true if the HitObject encodes a nop, otherwise returns false. [__requiresNVAPI] __target_intrinsic(hlsl) __target_intrinsic(glsl, "hitObjectIsEmptyNV($0)") + __glsl_extension(GL_EXT_ray_tracing) bool IsNop(); /// Queries ray properties from HitObject. Valid if the hit object represents a hit or a miss. @@ -8686,36 +8665,42 @@ struct HitObject [__requiresNVAPI] __target_intrinsic(hlsl) __target_intrinsic(glsl, "hitObjectGetShaderBindingTableRecordIndexNV($0)") + __glsl_extension(GL_EXT_ray_tracing) uint GetShaderTableIndex(); /// Returns the instance index of a hit. Valid if the hit object represents a hit. [__requiresNVAPI] __target_intrinsic(hlsl) __target_intrinsic(glsl, "hitObjectGetInstanceCustomIndexNV($0)") + __glsl_extension(GL_EXT_ray_tracing) uint GetInstanceIndex(); /// Returns the instance ID of a hit. Valid if the hit object represents a hit. [__requiresNVAPI] __target_intrinsic(hlsl) __target_intrinsic(glsl, "hitObjectGetInstanceIdNV($0)") + __glsl_extension(GL_EXT_ray_tracing) uint GetInstanceID(); /// Returns the geometry index of a hit. Valid if the hit object represents a hit. [__requiresNVAPI] __target_intrinsic(hlsl) __target_intrinsic(glsl, "hitObjectGetGeometryIndexNV($0)") + __glsl_extension(GL_EXT_ray_tracing) uint GetGeometryIndex(); /// Returns the primitive index of a hit. Valid if the hit object represents a hit. [__requiresNVAPI] __target_intrinsic(hlsl) __target_intrinsic(glsl, "hitObjectGetPrimitiveIndexNV($0)") + __glsl_extension(GL_EXT_ray_tracing) uint GetPrimitiveIndex(); /// Returns the hit kind. Valid if the hit object represents a hit. [__requiresNVAPI] __target_intrinsic(hlsl) __target_intrinsic(glsl, "hitObjectGetHitKindNV($0)") + __glsl_extension(GL_EXT_ray_tracing) uint GetHitKind(); /// Returns the attributes of a hit. Valid if the hit object represents a hit or a miss. @@ -8797,6 +8782,7 @@ struct HitObject __glsl_extension(GL_NV_shader_invocation_reorder) __glsl_extension(GL_EXT_ray_tracing) + __glsl_version(460) __target_intrinsic(glsl, "hitObjectRecordMissNV") static void __glslMakeMiss( out HitObject hitObj, @@ -8810,7 +8796,8 @@ struct HitObject __glsl_extension(GL_NV_shader_invocation_reorder) __glsl_extension(GL_EXT_ray_tracing) __glsl_extension(GL_NV_ray_tracing_motion_blur) - __target_intrinsic(glsl, "hitObjectRecordMissNV") + __glsl_version(460) + __target_intrinsic(glsl, "hitObjectRecordMissMotionNV") static void __glslMakeMotionMiss( out HitObject hitObj, uint MissShaderIndex, @@ -8822,7 +8809,7 @@ struct HitObject __glsl_extension(GL_NV_shader_invocation_reorder) __glsl_extension(GL_EXT_ray_tracing) - __target_intrinsic(glsl, "hitObjectRecordEmptyNV($0)") + __target_intrinsic(glsl, "hitObjectRecordEmptyNV") static void __glslMakeNop(out HitObject hitObj); __glsl_extension(GL_NV_shader_invocation_reorder) @@ -8844,6 +8831,7 @@ struct HitObject // "void hitObjectRecordHitWithIndexNV(hitObjectNV, accelerationStructureEXT,int,int,int,uint,uint,vec3,float,vec3,float,int);" __glsl_extension(GL_NV_shader_invocation_reorder) __glsl_extension(GL_EXT_ray_tracing) + __glsl_version(460) __target_intrinsic(glsl, "hitObjectRecordHitWithIndexNV") static void __glslMakeHitWithIndex( out HitObject hitObj, @@ -8929,7 +8917,7 @@ struct HitObject __glsl_extension(GL_NV_shader_invocation_reorder) __target_intrinsic(glsl, "hitObjectTraceRayNV") static void __glslTraceRay( - out HitObject hitObj, + out HitObject hitObject, RaytracingAccelerationStructure accelerationStructure, uint rayFlags, uint cullMask, @@ -8947,7 +8935,7 @@ struct HitObject __glsl_extension(GL_NV_ray_tracing_motion_blur) __target_intrinsic(glsl, "hitObjectTraceRayMotionNV") static void __glslTraceMotionRay( - out HitObject hitObj, + out HitObject hitObject, RaytracingAccelerationStructure accelerationStructure, uint rayFlags, uint cullMask, diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index 5603ef2a5..6699426d5 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -469,6 +469,16 @@ class ThisExpr: public Expr Scope* scope = nullptr; }; +// Represent a reference to the virtual __return_val object holding the return value of +// functions whose result type is non-copyable. +class ReturnValExpr : public Expr +{ + SLANG_AST_CLASS(ReturnValExpr) + + SLANG_UNREFLECTED + Scope* scope = nullptr; +}; + // An expression that binds a temporary variable in a local expression context class LetExpr: public Expr { diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index d38fd9374..dbf9bc8fc 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -220,6 +220,8 @@ struct ASTIterator void visitThisExpr(ThisExpr* expr) { iterator->maybeDispatchCallback(expr); } void visitThisTypeExpr(ThisTypeExpr* expr) { iterator->maybeDispatchCallback(expr); } + void visitReturnValExpr(ReturnValExpr* expr) { iterator->maybeDispatchCallback(expr); } + void visitAndTypeExpr(AndTypeExpr* expr) { iterator->maybeDispatchCallback(expr); diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 9704b4c87..f80de86fd 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -796,4 +796,14 @@ Type* removeParamDirType(Type* type) return type; } +bool isNonCopyableType(Type* type) +{ + auto declRefType = as<DeclRefType>(type); + if (!declRefType) + return false; + if (declRefType->getDeclRef().getDecl()->findModifier<NonCopyableTypeAttribute>()) + return true; + return false; +} + } // namespace Slang diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index c2a13542d..50d523cc5 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -834,5 +834,6 @@ class ModifiedType : public Type }; Type* removeParamDirType(Type* type); +bool isNonCopyableType(Type* type); } // namespace Slang diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index a00b1bea2..055364d5e 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -3769,6 +3769,32 @@ namespace Slang return CreateErrorExpr(expr); } + Expr* SemanticsExprVisitor::visitReturnValExpr(ReturnValExpr* expr) + { + auto scope = expr->scope; + if (scope) + { + auto parentFunc = as<CallableDecl>(getParentFunc(scope->containerDecl)); + if (parentFunc) + { + if (as<ErrorType>(parentFunc->returnType.type)) + { + expr->type = parentFunc->returnType.type; + return expr; + } + if (isNonCopyableType(parentFunc->returnType.type)) + { + expr->type.isLeftValue = true; + expr->type.type = parentFunc->returnType.type; + return expr; + } + } + } + getSink()->diagnose(expr, Diagnostics::returnValNotAvailable); + expr->type = getASTBuilder()->getErrorType(); + return expr; + } + Expr* SemanticsExprVisitor::visitAndTypeExpr(AndTypeExpr* expr) { // The left and right sides of an `&` for types must both be types. diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index c4a7b3e6d..37dcba3f4 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2409,6 +2409,7 @@ namespace Slang Expr* visitThisExpr(ThisExpr* expr); Expr* visitThisTypeExpr(ThisTypeExpr* expr); + Expr* visitReturnValExpr(ReturnValExpr* expr); Expr* visitAndTypeExpr(AndTypeExpr* expr); Expr* visitPointerTypeExpr(PointerTypeExpr* expr); Expr* visitModifiedTypeExpr(ModifiedTypeExpr* expr); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index b562ac880..7c8bab1ad 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -536,7 +536,7 @@ DIAGNOSTIC(38100, Error, typeDoesntImplementInterfaceRequirement, "type '$0' doe DIAGNOSTIC(38101, Error, thisExpressionOutsideOfTypeDecl, "'this' expression can only be used in members of an aggregate type") DIAGNOSTIC(38102, Error, initializerNotInsideType, "an 'init' declaration is only allowed inside a type or 'extension' declaration") DIAGNOSTIC(38103, Error, thisTypeOutsideOfTypeDecl, "'This' type can only be used inside of an aggregate type") - +DIAGNOSTIC(38104, Error, returnValNotAvailable, "cannot use '__return_val' here. '__return_val' is defined only in functions that return a non-copyable value.") DIAGNOSTIC(38020, Error, mismatchEntryPointTypeArgument, "expecting $0 entry-point type arguments, provided $1.") DIAGNOSTIC(38021, Error, typeArgumentForGenericParameterDoesNotConformToInterface, "type argument `$0` for generic parameter `$1` does not conform to interface `$2`.") diff --git a/source/slang/slang-emit-c-like.cpp b/source/slang/slang-emit-c-like.cpp index fe25c3f19..a8080851f 100644 --- a/source/slang/slang-emit-c-like.cpp +++ b/source/slang/slang-emit-c-like.cpp @@ -1535,25 +1535,6 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst) if(as<IRUnconditionalBranch>(user)) return false; - // HACK: As a special case, an `allocateOpaqueHandle` operation should - // only be folded in if its only use is as the operand of a `store` - // that will *itself* get peephole merged in as the initial-value expression - // of a `var`: - // - if (inst->getOp() == kIROp_AllocateOpaqueHandle) - { - auto store = as<IRStore>(user); - if (!store) return false; - if (store->getVal() != inst) return false; - - auto var = as<IRVar>(store->getPtr()); - if (!var) return false; - - if(var->getNextInst() != store) return false; - - return true; - } - // Okay, if we reach this point then the user comes later in // the same block, and there are no instructions with side // effects in between, so it seems safe to fold things in. @@ -2003,7 +1984,6 @@ void CLikeSourceEmitter::defaultEmitInstExpr(IRInst* inst, const EmitOpInfo& inO case kIROp_undefined: case kIROp_DefaultConstruct: - case kIROp_AllocateOpaqueHandle: m_writer->emit(getName(inst)); break; @@ -2676,11 +2656,7 @@ void CLikeSourceEmitter::_emitInst(IRInst* inst) break; case kIROp_AllocateOpaqueHandle: - { - _emitAllocateOpaqueHandleImpl(inst); - } break; - case kIROp_Var: { auto var = cast<IRVar>(inst); @@ -2875,11 +2851,6 @@ void CLikeSourceEmitter::_emitInstAsDefaultInitializedVar(IRInst* inst, IRType* m_writer->emit(";\n"); } -void CLikeSourceEmitter::_emitAllocateOpaqueHandleImpl(IRInst* allocateInst) -{ - _emitInstAsDefaultInitializedVar(allocateInst, allocateInst->getDataType()); -} - void CLikeSourceEmitter::emitSemanticsUsingVarLayout(IRVarLayout* varLayout) { if(auto semanticAttr = varLayout->findAttr<IRSemanticAttr>()) diff --git a/source/slang/slang-emit-c-like.h b/source/slang/slang-emit-c-like.h index 02ab28028..62f6d20a2 100644 --- a/source/slang/slang-emit-c-like.h +++ b/source/slang/slang-emit-c-like.h @@ -309,9 +309,8 @@ public: void emitStore(IRStore* store); virtual void _emitStoreImpl(IRStore* store); - virtual void _emitInstAsVarInitializerImpl(IRInst* inst); void _emitInstAsDefaultInitializedVar(IRInst* inst, IRType* type); - virtual void _emitAllocateOpaqueHandleImpl(IRInst* allocateInst); + void _emitInstAsVarInitializerImpl(IRInst* inst); UInt getBindingOffset(EmitVarChain* chain, LayoutResourceKind kind); UInt getBindingSpace(EmitVarChain* chain, LayoutResourceKind kind); diff --git a/source/slang/slang-emit-glsl.cpp b/source/slang/slang-emit-glsl.cpp index c651c8735..94c85409f 100644 --- a/source/slang/slang-emit-glsl.cpp +++ b/source/slang/slang-emit-glsl.cpp @@ -927,61 +927,6 @@ void GLSLSourceEmitter::emitLoopControlDecorationImpl(IRLoopControlDecoration* d } } -void GLSLSourceEmitter::_emitInstAsVarInitializerImpl(IRInst* inst) -{ - // Some opcodes can be folded into a variable initialization - // by allowing the variable to be "default-constructed." - // - switch (inst->getOp()) - { - case kIROp_AllocateOpaqueHandle: - // - // Note: semantically, we should only elide the initializer - // if `inst` is able to be folded here, since otherwise - // it could be a single allocation that is used to initialize - // multiple local variables (which should then alias the - // same location). - // - // However, since GlSL doesn't support assignment of opaque - // handle types, code will fail to compile downstream in - // the case where the initializer *doesn't* fold. - // - // The decision being made here should help ensure that we - // don't emit code that silently has different semantics - // than the input. - // - if (shouldFoldInstIntoUseSites(inst)) - { - return; - } - break; - - default: - break; - } - - // We fall back to the default behavior for all targets, - // which is to emit `inst` as an initial-value expression - // after an `=`. - // - Super::_emitInstAsVarInitializerImpl(inst); -} - -void GLSLSourceEmitter::_emitStoreImpl(IRStore* store) -{ - auto srcVal = store->getVal(); - switch (srcVal->getOp()) - { - default: - Super::_emitStoreImpl(store); - break; - - case kIROp_AllocateOpaqueHandle: - break; - } - -} - void GLSLSourceEmitter::_emitSpecialFloatImpl(IRType* type, const char* valueExpr) { if( type->getOp() != kIROp_FloatType ) @@ -2215,9 +2160,17 @@ void GLSLSourceEmitter::emitParamTypeImpl(IRType* type, String const& name) { if (auto refType = as<IRRefType>(type)) { - _requireGLSLExtension(UnownedStringSlice("GL_EXT_spirv_intrinsics")); - m_writer->emit("spirv_by_reference "); type = refType->getValueType(); + + if (as<IRRayQueryType>(type) || as<IRHitObjectType>(type)) + { + // GLSL will automatically pass these by reference, so we don't need to do anything. + } + else + { + _requireGLSLExtension(UnownedStringSlice("GL_EXT_spirv_intrinsics")); + m_writer->emit("spirv_by_reference "); + } } else if (auto spirvLiteralType = as<IRSPIRVLiteralType>(type)) { diff --git a/source/slang/slang-emit-glsl.h b/source/slang/slang-emit-glsl.h index 7c1a15315..780b24453 100644 --- a/source/slang/slang-emit-glsl.h +++ b/source/slang/slang-emit-glsl.h @@ -56,9 +56,6 @@ protected: virtual void emitSimpleValueImpl(IRInst* inst) SLANG_OVERRIDE; virtual void emitLoopControlDecorationImpl(IRLoopControlDecoration* decl) SLANG_OVERRIDE; - virtual void _emitInstAsVarInitializerImpl(IRInst* inst) SLANG_OVERRIDE; - virtual void _emitStoreImpl(IRStore* store) SLANG_OVERRIDE; - void _emitGLSLTextureOrTextureSamplerType(IRTextureTypeBase* type, char const* baseName); void _emitGLSLStructuredBuffer(IRGlobalParam* varDecl, IRHLSLStructuredBufferTypeBase* structuredBufferType); diff --git a/source/slang/slang-emit-hlsl.cpp b/source/slang/slang-emit-hlsl.cpp index 30de45773..66902a624 100644 --- a/source/slang/slang-emit-hlsl.cpp +++ b/source/slang/slang-emit-hlsl.cpp @@ -1136,31 +1136,6 @@ void HLSLSourceEmitter::_emitPrefixTypeAttr(IRAttr* attr) } } -void HLSLSourceEmitter::_emitInstAsVarInitializerImpl(IRInst* inst) -{ - // Some opcodes can be folded into a variable initialization - // by allowing the variable to be "default-constructed." - // - switch (inst->getOp()) - { - case kIROp_AllocateOpaqueHandle: - if (shouldFoldInstIntoUseSites(inst)) - { - return; - } - break; - - default: - break; - } - - // We fall back to the default behavior for all targets, - // which is to emit `inst` as an initial-value expression - // after an `=`. - // - Super::_emitInstAsVarInitializerImpl(inst); -} - void HLSLSourceEmitter::emitSimpleFuncParamImpl(IRParam* param) { emitRateQualifiers(param); diff --git a/source/slang/slang-emit-hlsl.h b/source/slang/slang-emit-hlsl.h index f2440fe38..08363bceb 100644 --- a/source/slang/slang-emit-hlsl.h +++ b/source/slang/slang-emit-hlsl.h @@ -64,8 +64,6 @@ protected: void _emitPrefixTypeAttr(IRAttr* attr) SLANG_OVERRIDE; - virtual void _emitInstAsVarInitializerImpl(IRInst* inst) SLANG_OVERRIDE; - // Emit a single `register` semantic, as appropriate for a given resource-type-specific layout info // Keyword to use in the uniform case (`register` for globals, `packoffset` inside a `cbuffer`) void _emitHLSLRegisterSemantic(LayoutResourceKind kind, EmitVarChain* chain, char const* uniformSemanticSpelling = "register"); diff --git a/source/slang/slang-ir-specialize-resources.cpp b/source/slang/slang-ir-specialize-resources.cpp index 620c4e508..87d00a32b 100644 --- a/source/slang/slang-ir-specialize-resources.cpp +++ b/source/slang/slang-ir-specialize-resources.cpp @@ -355,6 +355,12 @@ struct ResourceOutputSpecializationPass if(as<IRSamplerStateTypeBase>(type)) return true; + if (as<IRRayQueryType>(type)) + return true; + + if (as<IRHitObjectType>(type)) + return true; + // TODO: more cases here? return false; diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index bbd494ee9..ed1da3d25 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -563,6 +563,10 @@ struct IRGenContext // The IR witness value to use for `ThisType` IRInst* thisTypeWitness = nullptr; + // The return destination parameter to write to at return sites. + // (For use by functions that returns non-copyable types) + LoweredValInfo returnDestination; + bool includeDebugInfo = false; explicit IRGenContext(SharedIRGenContext* inShared, ASTBuilder* inAstBuilder) @@ -804,6 +808,11 @@ LoweredValInfo lowerRValueExpr( IRGenContext* context, Expr* expr); +void lowerRValueExprWithDestination( + IRGenContext* context, + LoweredValInfo destination, + Expr* expr); + IRType* lowerType( IRGenContext* context, Type* type); @@ -1038,8 +1047,6 @@ LoweredValInfo extractField( } } - - LoweredValInfo materialize( IRGenContext* context, LoweredValInfo lowered) @@ -1232,6 +1239,12 @@ void assign( LoweredValInfo const& left, LoweredValInfo const& right); +void assignExpr( + IRGenContext* context, + const LoweredValInfo& inLeft, + Expr* rightExpr, + SourceLoc assignmentLoc); + IRInst* getAddress( IRGenContext* context, LoweredValInfo const& inVal, @@ -2703,6 +2716,9 @@ struct IRLoweringParameterInfo // Is this the representation of a `this` parameter? bool isThisParam = false; + + // Is this the destination of address for non-copyable return val? + bool isReturnDestination = false; }; // // We need a way to be able to create a `IRLoweringParameterInfo` given the declaration @@ -2772,6 +2788,19 @@ void addThisParameter( ioParameterLists->params.add(info); } + +void maybeAddReturnDestinationParam(ParameterLists* ioParameterLists, Type* resultType) +{ + if (isNonCopyableType(resultType)) + { + IRLoweringParameterInfo info; + info.type = resultType; + info.decl = nullptr; + info.direction = kParameterDirection_Ref; + info.isReturnDestination = true; + ioParameterLists->params.add(info); + } +} // // And here is our function that will do the recursive walk: void collectParameterLists( @@ -2842,6 +2871,7 @@ void collectParameterLists( { ioParameterLists->params.add(getParameterInfo(context, paramDeclRef)); } + maybeAddReturnDestinationParam(ioParameterLists, getResultType(context->astBuilder, callableDeclRef)); } } } @@ -2883,6 +2913,10 @@ struct FuncDeclBaseTypeInfo IRType* resultType; ParameterLists parameterLists; List<IRType*> paramTypes; + // If the function returns a non-copyable value, this + // flag is set to indicate that the result should be + // returned via the last ref parameter. + bool returnViaLastRefParam = false; }; void _lowerFuncDeclBaseTypeInfo( @@ -2947,24 +2981,34 @@ void _lowerFuncDeclBaseTypeInfo( } auto& irResultType = outInfo.resultType; - irResultType = lowerType(context, getResultType(context->astBuilder, declRef)); - - if (auto setterDeclRef = declRef.as<SetterDecl>()) + + if (parameterLists.params.getCount() && parameterLists.params.getLast().isReturnDestination) { - // A `set` accessor always returns `void` - // - // TODO: We should handle this by making the result - // type of a `set` accessor be represented accurately - // at the AST level (ditto for the `ref` case below). - // - irResultType = builder->getVoidType(); + irResultType = context->irBuilder->getVoidType(); + outInfo.returnViaLastRefParam = true; } - - if( auto refAccessorDeclRef = declRef.as<RefAccessorDecl>() ) + else { - // A `ref` accessor needs to return a *pointer* to the value - // being accessed, rather than a simple value. - irResultType = builder->getPtrType(irResultType); + irResultType = lowerType(context, getResultType(context->astBuilder, declRef)); + + + if (auto setterDeclRef = declRef.as<SetterDecl>()) + { + // A `set` accessor always returns `void` + // + // TODO: We should handle this by making the result + // type of a `set` accessor be represented accurately + // at the AST level (ditto for the `ref` case below). + // + irResultType = builder->getVoidType(); + } + + if (auto refAccessorDeclRef = declRef.as<RefAccessorDecl>()) + { + // A `ref` accessor needs to return a *pointer* to the value + // being accessed, rather than a simple value. + irResultType = builder->getPtrType(irResultType); + } } if (!getErrorCodeType(context->astBuilder, declRef)->equals(context->astBuilder->getBottomType())) @@ -3023,10 +3067,8 @@ static LoweredValInfo _emitCallToAccessor( return result; } -// - template<typename Derived> -struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> +struct ExprLoweringContext { static bool isLValueContext() { return Derived::_isLValueContext(); } @@ -3035,20 +3077,493 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> IRBuilder* getBuilder() { return context->irBuilder; } ASTBuilder* getASTBuilder() { return context->astBuilder; } + + struct ResolvedCallInfo + { + DeclRef<Decl> funcDeclRef; + Expr* baseExpr = nullptr; + }; + + // Try to resolve a the function expression for a call + // into a reference to a specific declaration, along + // with some contextual information about the declaration + // we are calling. + bool tryResolveDeclRefForCall( + Expr* funcExpr, + ResolvedCallInfo* outInfo) + { + // TODO: unwrap any "identity" expressions that might + // be wrapping the callee. + + // First look to see if the expression references a + // declaration at all. + auto declRefExpr = as<DeclRefExpr>(funcExpr); + if (!declRefExpr) + return false; + + // A little bit of future proofing here: if we ever + // allow higher-order functions, then we might be + // calling through a variable/field that has a function + // type, but is not itself a function. + // In such a case we should be careful to not statically + // resolve things. + // + if (auto callableDecl = as<CallableDecl>(declRefExpr->declRef.getDecl())) + { + // Okay, the declaration is directly callable, so we can continue. + } + else + { + // The callee declaration isn't itself a callable (it must have + // a function type, though). + return false; + } + + // Now we can look at the specific kinds of declaration references, + // and try to tease them apart. + if (auto memberFuncExpr = as<MemberExpr>(funcExpr)) + { + outInfo->funcDeclRef = memberFuncExpr->declRef; + outInfo->baseExpr = memberFuncExpr->baseExpression; + return true; + } + else if (auto staticMemberFuncExpr = as<StaticMemberExpr>(funcExpr)) + { + outInfo->funcDeclRef = staticMemberFuncExpr->declRef; + return true; + } + else if (auto varExpr = as<VarExpr>(funcExpr)) + { + outInfo->funcDeclRef = varExpr->declRef; + return true; + } + else + { + // Seems to be a case of declaration-reference we don't know about. + SLANG_UNEXPECTED("unknown declaration reference kind"); + //return false; + } + } + + /// Return `expr` with any outer casts to interface types stripped away + Expr* maybeIgnoreCastToInterface(Expr* expr) + { + auto e = expr; + while (auto castExpr = as<CastToSuperTypeExpr>(e)) + { + if (auto declRefType = as<DeclRefType>(e->type)) + { + if (declRefType->getDeclRef().as<InterfaceDecl>()) + { + e = castExpr->valueArg; + continue; + } + } + else if (auto andType = as<AndType>(e->type)) + { + // TODO: We might eventually need to tell the difference + // between conjunctions of interfaces and conjunctions + // that might include non-interface types. + // + // For now we assume that any case to a conjunction + // is effectively a cast to an interface type. + // + e = castExpr->valueArg; + continue; + } + break; + } + return e; + } + + // Lower an expression that should have the same l-value-ness // as the visitor itself. LoweredValInfo lowerSubExpr(Expr* expr) { IRBuilderSourceLocRAII sourceLocInfo(getBuilder(), expr->loc); - return this->dispatch(expr); + if (isLValueContext()) + return lowerLValueExpr(context, expr); + return lowerRValueExpr(context, expr); } - LoweredValInfo lowerSubExpr(Expr* expr, IRGenContext* subContext) + /// Create IR instructions for an argument at a call site, based on + /// AST-level expressions plus function signature information. + /// + /// The `funcType` parameter is always required, and specifies the types + /// of all the parameters. The `funcDeclRef` parameter is only required + /// if there are parameter positions for which the matching argument is + /// absent. + /// + void addDirectCallArgs( + InvokeExpr* expr, + Index argIndex, + IRType* paramType, + ParameterDirection paramDirection, + DeclRef<ParamDecl> paramDeclRef, + List<IRInst*>* ioArgs, + List<OutArgumentFixup>* ioFixups) { - IRBuilderSourceLocRAII sourceLocInfo(getBuilder(), expr->loc); - Derived d; - d.context = subContext; - return d.dispatch(expr); + Count argCount = expr->arguments.getCount(); + if (argIndex < argCount) + { + auto argExpr = expr->arguments[argIndex]; + addCallArgsForParam(context, paramType, paramDirection, argExpr, ioArgs, ioFixups); + } + else + { + // We have run out of arguments supplied at the call site, + // but there are still parameters remaining. This must mean + // that these parameters have default argument expressions + // associated with them. + // + // Currently we simply extract the initial-value expression + // from the parameter declaration and then lower it in + // the context of the caller. + // + // Note that the expression could involve subsitutions because + // in the general case it could depend on the generic parameters + // used the specialize the callee. For now we do not handle that + // case, and simply ignore generic arguments. + // + SubstExpr<Expr> argExpr = getInitExpr(getASTBuilder(), paramDeclRef); + SLANG_ASSERT(argExpr); + + IRGenEnv subEnvStorage; + IRGenEnv* subEnv = &subEnvStorage; + subEnv->outer = context->env; + + IRGenContext subContextStorage = *context; + IRGenContext* subContext = &subContextStorage; + subContext->env = subEnv; + + _lowerSubstitutionEnv(subContext, argExpr.getSubsts() ? argExpr.getSubsts().declRef : nullptr); + + addCallArgsForParam(subContext, paramType, paramDirection, argExpr.getExpr(), ioArgs, ioFixups); + + // TODO: The approach we are taking here to default arguments + // is simplistic, and has consequences for the front-end as + // well as binary serialization of modules. + // + // We could consider some more refined approaches where, e.g., + // functions with default arguments generate multiple IR-level + // functions, that compute and provide the default values. + // + // Alternatively, each parameter with defaults could be generated + // into its own callable function that provides the default value, + // so that calling modules can call into a pre-generated function. + // + // Each of these options involves trade-offs, and we need to + // make a conscious decision at some point. + + // Assert that such an expression must have been present. + } + } + + void addDirectCallArgs( + InvokeExpr* expr, + FuncType* funcType, + List<IRInst*>* ioArgs, + List<OutArgumentFixup>* ioFixups) + { + Count argCount = expr->arguments.getCount(); + SLANG_ASSERT(argCount == funcType->getParamCount()); + + for (Index i = 0; i < argCount; ++i) + { + IRType* paramType = lowerType(context, funcType->getParamType(i)); + ParameterDirection paramDirection = funcType->getParamDirection(i); + addDirectCallArgs(expr, i, paramType, paramDirection, DeclRef<ParamDecl>(), ioArgs, ioFixups); + } + } + + void addDirectCallArgs( + InvokeExpr* expr, + DeclRef<CallableDecl> funcDeclRef, + List<IRInst*>* ioArgs, + List<OutArgumentFixup>* ioFixups) + { + Count argCounter = 0; + for (auto paramDeclRef : getMembersOfType<ParamDecl>(getASTBuilder(), funcDeclRef)) + { + auto paramDecl = paramDeclRef.getDecl(); + IRType* paramType = lowerType(context, getType(getASTBuilder(), paramDeclRef)); + auto paramDirection = getParameterDirection(paramDecl); + + Index argIndex = argCounter++; + addDirectCallArgs(expr, argIndex, paramType, paramDirection, paramDeclRef, ioArgs, ioFixups); + } + } + + // Add arguments that appeared directly in an argument list + // to the list of argument values for a call. + void addDirectCallArgs( + InvokeExpr* expr, + DeclRef<Decl> funcDeclRef, + List<IRInst*>* ioArgs, + List<OutArgumentFixup>* ioFixups) + { + if (auto callableDeclRef = funcDeclRef.as<CallableDecl>()) + { + addDirectCallArgs(expr, callableDeclRef, ioArgs, ioFixups); + } + else + { + SLANG_UNEXPECTED("callee was not a callable decl"); + } + } + + void addFuncBaseArgs( + LoweredValInfo funcVal, + List<IRInst*>* /*ioArgs*/) + { + switch (funcVal.flavor) + { + default: + return; + } + } + + + void _lowerSubstitutionArg(IRGenContext* subContext, GenericAppDeclRef* subst, Decl* paramDecl, Index argIndex) + { + SLANG_ASSERT(argIndex < subst->getArgs().getCount()); + auto argVal = lowerVal(subContext, subst->getArgs()[argIndex]); + subContext->setValue(paramDecl, argVal); + } + + void _lowerSubstitutionEnv(IRGenContext* subContext, DeclRefBase* subst) + { + if (!subst) return; + _lowerSubstitutionEnv(subContext, subst->getBase()); + + if (auto genSubst = as<GenericAppDeclRef>(subst)) + { + auto genDecl = genSubst->getGenericDecl(); + + Index argCounter = 0; + for (auto memberDecl : genDecl->members) + { + if (auto typeParamDecl = as<GenericTypeParamDecl>(memberDecl)) + { + _lowerSubstitutionArg(subContext, genSubst, typeParamDecl, argCounter++); + } + else if (auto valParamDecl = as<GenericValueParamDecl>(memberDecl)) + { + _lowerSubstitutionArg(subContext, genSubst, valParamDecl, argCounter++); + } + } + for (auto memberDecl : genDecl->members) + { + if (auto constraintDecl = as<GenericTypeConstraintDecl>(memberDecl)) + { + _lowerSubstitutionArg(subContext, genSubst, constraintDecl, argCounter++); + } + } + } + // TODO: also need to handle this-type substitution here? + } + + /// Lower an invoke expr, and attempt to fuse a store of the expr's result into destination. + /// If the store is fused, returns LoweredValInfo::None. Otherwise, returns the IR val representing the RValue. + LoweredValInfo visitInvokeExprImpl(InvokeExpr* expr, LoweredValInfo destination, const TryClauseEnvironment& tryEnv) + { + auto type = lowerType(context, expr->type); + + // We are going to look at the syntactic form of + // the "function" expression, so that we can avoid + // a lot of complexity that would come from lowering + // it as a general expression first, and then trying + // to apply it. For example, given `obj.f(a,b)` we + // will try to detect that we are trying to compute + // something like `ObjType::f(obj, a, b)` (in pseudo-code), + // rather than trying to construct a meaningful + // intermediate value for `obj.f` first. + // + // Note that this doe not preclude having support + // for directly generating code from `obj.f` - it + // just may be that such usage is more complicated. + + // Along the way, we may end up collecting additional + // arguments that will be part of the call. + List<IRInst*> irArgs; + + // We will also collect "fixup" actions that need + // to be performed after the call, in order to + // copy the final values for `out` parameters + // back to their arguments. + List<OutArgumentFixup> argFixups; + + auto funcExpr = expr->functionExpr; + ResolvedCallInfo resolvedInfo; + if (tryResolveDeclRefForCall(funcExpr, &resolvedInfo)) + { + // In this case we know exactly what declaration we + // are going to call, and so we can resolve things + // appropriately. + auto funcDeclRef = resolvedInfo.funcDeclRef; + auto baseExpr = resolvedInfo.baseExpr; + + // If the thing being invoked is a subscript operation, + // then we need to handle multiple extra details + // that don't arise for other kinds of calls. + // + // TODO: subscript operations probably deserve to + // be handled on their own path for this reason... + // + if (auto subscriptDeclRef = funcDeclRef.template as<SubscriptDecl>()) + { + // A reference to a subscript declaration is a special case, + // because it is not possible to call a subscript directly; + // we must call one of its accessors. + // + auto loweredBase = lowerSubExpr(baseExpr); + addDirectCallArgs(expr, funcDeclRef, &irArgs, &argFixups); + auto result = lowerStorageReference(context, type, subscriptDeclRef, loweredBase, irArgs.getCount(), irArgs.getBuffer()); + + // TODO: Applying the fixups for arguments to the subscript at this point + // won't technically be correct, since the call to the subscript may + // not have occured at this point. + // + // It seems like we need to either: + // + // * Capture the arguments to the subscript as `LoweredValInfo` instead of `IRInst*` + // so that we can deal with everything related to fixups around the actual call + // site. + // + // OR + // + // * Handle everything to do with "fixups" differently, by treating them as deferred + // actions that gert queued up on the context itself and then flushed at certain + // well-defined points, so that we don't have to be as careful around them. + // + // OR + // + // * Switch to a more "destination-driven" approach to code generation, where we + // can determine on entry to the lowering of a sub-expression whether it will be + // used for read, write, or read/write, and resolve things like the choice of + // accessor at that point instead. + // + applyOutArgumentFixups(context, argFixups); + return result; + } + + // First comes the `this` argument if we are calling + // a member function: + if (baseExpr) + { + // The base expression might be an "upcast" to a base interface, in + // which case we don't want to emit the result of the cast, but instead + // the source. + // + baseExpr = this->maybeIgnoreCastToInterface(baseExpr); + + auto thisType = getThisParamTypeForCallable(context, funcDeclRef); + auto irThisType = lowerType(context, thisType); + addCallArgsForParam( + context, + irThisType, + getThisParamDirection(funcDeclRef.getDecl(), kParameterDirection_In), + baseExpr, + &irArgs, + &argFixups); + } + + // Then we have the "direct" arguments to the call. + // These may include `out` and `inout` arguments that + // require "fixup" work on the other side. + // + FuncDeclBaseTypeInfo funcTypeInfo; + _lowerFuncDeclBaseTypeInfo(context, funcDeclRef.template as<FunctionDeclBase>(), funcTypeInfo); + + auto funcType = funcTypeInfo.type; + addDirectCallArgs(expr, funcDeclRef, &irArgs, &argFixups); + + LoweredValInfo result; + if (funcTypeInfo.returnViaLastRefParam) + { + // If the function returns a non-copyable type, then we need to + // pass in the destination that receives the result value as an `__ref` parameter. + // + if (destination.flavor != LoweredValInfo::Flavor::None) + { + // If we have a known destination, we can use it directly as argument to the call. + irArgs.add(destination.val); + result = LoweredValInfo(); + } + else + { + // Otherwise, we need to create a temporary variable to hold the result. + // + auto tempVar = context->irBuilder->emitVar(tryGetPointedToType(context->irBuilder, funcTypeInfo.paramTypes.getLast())); + irArgs.add(tempVar); + result = LoweredValInfo::ptr(tempVar); + } + } + + auto callResult = emitCallToDeclRef( + context, + type, + funcDeclRef, + funcType, + irArgs, + tryEnv); + applyOutArgumentFixups(context, argFixups); + + if (funcTypeInfo.returnViaLastRefParam) + return result; + return callResult; + } + else if (auto funcType = as<FuncType>(expr->functionExpr->type)) + { + auto funcVal = lowerRValueExpr(context, expr->functionExpr); + addDirectCallArgs(expr, funcType, &irArgs, &argFixups); + + auto result = emitCallToVal(context, type, funcVal, irArgs.getCount(), irArgs.getBuffer(), tryEnv); + + applyOutArgumentFixups(context, argFixups); + return result; + } + + + // TODO: In this case we should be emitting code for the callee as + // an ordinary expression, then emitting the arguments according + // to the type information on the callee (e.g., which parameters + // are `out` or `inout`, and then finally emitting the `call` + // instruction. + // + // We don't currently have the case of emitting arguments according + // to function type info (instead of declaration info), and really + // this case can't occur unless we start adding first-class functions + // to the source language. + // + // For now we just bail out with an error. + // + SLANG_UNEXPECTED("could not resolve target declaration for call"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + +}; + +template<typename Derived> +struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo> +{ + static bool isLValueContext() { return Derived::_isLValueContext(); } + + ExprLoweringContext<Derived> sharedLoweringContext; + + IRGenContext*& context; + + ExprLoweringVisitorBase() + : context(sharedLoweringContext.context) + { + } + + IRBuilder* getBuilder() { return context->irBuilder; } + ASTBuilder* getASTBuilder() { return context->astBuilder; } + LoweredValInfo lowerSubExpr(Expr* expr) + { + return sharedLoweringContext.lowerSubExpr(expr); } LoweredValInfo visitIncompleteExpr(IncompleteExpr*) @@ -3380,12 +3895,17 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> return context->thisVal; } + LoweredValInfo visitReturnValExpr(ReturnValExpr*) + { + return context->returnDestination; + } + LoweredValInfo visitMemberExpr(MemberExpr* expr) { auto loweredType = lowerType(context, expr->type); auto baseExpr = expr->baseExpression; - baseExpr = maybeIgnoreCastToInterface(baseExpr); + baseExpr = sharedLoweringContext.maybeIgnoreCastToInterface(baseExpr); auto loweredBase = lowerSubExpr(baseExpr); auto declRef = expr->declRef; @@ -3812,281 +4332,6 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> UNREACHABLE_RETURN(LoweredValInfo()); } - void _lowerSubstitutionArg(IRGenContext* subContext, GenericAppDeclRef* subst, Decl* paramDecl, Index argIndex) - { - SLANG_ASSERT(argIndex < subst->getArgs().getCount()); - auto argVal = lowerVal(subContext, subst->getArgs()[argIndex]); - subContext->setValue(paramDecl, argVal); - } - - void _lowerSubstitutionEnv(IRGenContext* subContext, DeclRefBase* subst) - { - if(!subst) return; - _lowerSubstitutionEnv(subContext, subst->getBase()); - - if (auto genSubst = as<GenericAppDeclRef>(subst)) - { - auto genDecl = genSubst->getGenericDecl(); - - Index argCounter = 0; - for( auto memberDecl: genDecl->members ) - { - if(auto typeParamDecl = as<GenericTypeParamDecl>(memberDecl) ) - { - _lowerSubstitutionArg(subContext, genSubst, typeParamDecl, argCounter++); - } - else if( auto valParamDecl = as<GenericValueParamDecl>(memberDecl) ) - { - _lowerSubstitutionArg(subContext, genSubst, valParamDecl, argCounter++); - } - } - for( auto memberDecl: genDecl->members ) - { - if(auto constraintDecl = as<GenericTypeConstraintDecl>(memberDecl) ) - { - _lowerSubstitutionArg(subContext, genSubst, constraintDecl, argCounter++); - } - } - } - // TODO: also need to handle this-type substitution here? - } - - /// Create IR instructions for an argument at a call site, based on - /// AST-level expressions plus function signature information. - /// - /// The `funcType` parameter is always required, and specifies the types - /// of all the parameters. The `funcDeclRef` parameter is only required - /// if there are parameter positions for which the matching argument is - /// absent. - /// - void addDirectCallArgs( - InvokeExpr* expr, - Index argIndex, - IRType* paramType, - ParameterDirection paramDirection, - DeclRef<ParamDecl> paramDeclRef, - List<IRInst*>* ioArgs, - List<OutArgumentFixup>* ioFixups) - { - Count argCount = expr->arguments.getCount(); - if (argIndex < argCount) - { - auto argExpr = expr->arguments[argIndex]; - addCallArgsForParam(context, paramType, paramDirection, argExpr, ioArgs, ioFixups); - } - else - { - // We have run out of arguments supplied at the call site, - // but there are still parameters remaining. This must mean - // that these parameters have default argument expressions - // associated with them. - // - // Currently we simply extract the initial-value expression - // from the parameter declaration and then lower it in - // the context of the caller. - // - // Note that the expression could involve subsitutions because - // in the general case it could depend on the generic parameters - // used the specialize the callee. For now we do not handle that - // case, and simply ignore generic arguments. - // - SubstExpr<Expr> argExpr = getInitExpr(getASTBuilder(), paramDeclRef); - SLANG_ASSERT(argExpr); - - IRGenEnv subEnvStorage; - IRGenEnv* subEnv = &subEnvStorage; - subEnv->outer = context->env; - - IRGenContext subContextStorage = *context; - IRGenContext* subContext = &subContextStorage; - subContext->env = subEnv; - - _lowerSubstitutionEnv(subContext, argExpr.getSubsts() ? argExpr.getSubsts().declRef : nullptr); - - addCallArgsForParam(subContext, paramType, paramDirection, argExpr.getExpr(), ioArgs, ioFixups); - - // TODO: The approach we are taking here to default arguments - // is simplistic, and has consequences for the front-end as - // well as binary serialization of modules. - // - // We could consider some more refined approaches where, e.g., - // functions with default arguments generate multiple IR-level - // functions, that compute and provide the default values. - // - // Alternatively, each parameter with defaults could be generated - // into its own callable function that provides the default value, - // so that calling modules can call into a pre-generated function. - // - // Each of these options involves trade-offs, and we need to - // make a conscious decision at some point. - - // Assert that such an expression must have been present. - } - } - - void addDirectCallArgs( - InvokeExpr* expr, - FuncType* funcType, - List<IRInst*>* ioArgs, - List<OutArgumentFixup>* ioFixups) - { - Count argCount = expr->arguments.getCount(); - SLANG_ASSERT(argCount == funcType->getParamCount()); - - for(Index i = 0; i < argCount; ++i) - { - IRType* paramType = lowerType(context, funcType->getParamType(i)); - ParameterDirection paramDirection = funcType->getParamDirection(i); - addDirectCallArgs(expr, i, paramType, paramDirection, DeclRef<ParamDecl>(), ioArgs, ioFixups); - } - } - - - void addDirectCallArgs( - InvokeExpr* expr, - DeclRef<CallableDecl> funcDeclRef, - List<IRInst*>* ioArgs, - List<OutArgumentFixup>* ioFixups) - { - Count argCounter = 0; - for (auto paramDeclRef : getMembersOfType<ParamDecl>(getASTBuilder(), funcDeclRef)) - { - auto paramDecl = paramDeclRef.getDecl(); - IRType* paramType = lowerType(context, getType(getASTBuilder(), paramDeclRef)); - auto paramDirection = getParameterDirection(paramDecl); - - Index argIndex = argCounter++; - addDirectCallArgs(expr, argIndex, paramType, paramDirection, paramDeclRef, ioArgs, ioFixups); - } - } - - // Add arguments that appeared directly in an argument list - // to the list of argument values for a call. - void addDirectCallArgs( - InvokeExpr* expr, - DeclRef<Decl> funcDeclRef, - List<IRInst*>* ioArgs, - List<OutArgumentFixup>* ioFixups) - { - if (auto callableDeclRef = funcDeclRef.as<CallableDecl>()) - { - addDirectCallArgs(expr, callableDeclRef, ioArgs, ioFixups); - } - else - { - SLANG_UNEXPECTED("callee was not a callable decl"); - } - } - - void addFuncBaseArgs( - LoweredValInfo funcVal, - List<IRInst*>* /*ioArgs*/) - { - switch (funcVal.flavor) - { - default: - return; - } - } - - struct ResolvedCallInfo - { - DeclRef<Decl> funcDeclRef; - Expr* baseExpr = nullptr; - }; - - // Try to resolve a the function expression for a call - // into a reference to a specific declaration, along - // with some contextual information about the declaration - // we are calling. - bool tryResolveDeclRefForCall( - Expr* funcExpr, - ResolvedCallInfo* outInfo) - { - // TODO: unwrap any "identity" expressions that might - // be wrapping the callee. - - // First look to see if the expression references a - // declaration at all. - auto declRefExpr = as<DeclRefExpr>(funcExpr); - if(!declRefExpr) - return false; - - // A little bit of future proofing here: if we ever - // allow higher-order functions, then we might be - // calling through a variable/field that has a function - // type, but is not itself a function. - // In such a case we should be careful to not statically - // resolve things. - // - if(auto callableDecl = as<CallableDecl>(declRefExpr->declRef.getDecl())) - { - // Okay, the declaration is directly callable, so we can continue. - } - else - { - // The callee declaration isn't itself a callable (it must have - // a function type, though). - return false; - } - - // Now we can look at the specific kinds of declaration references, - // and try to tease them apart. - if (auto memberFuncExpr = as<MemberExpr>(funcExpr)) - { - outInfo->funcDeclRef = memberFuncExpr->declRef; - outInfo->baseExpr = memberFuncExpr->baseExpression; - return true; - } - else if (auto staticMemberFuncExpr = as<StaticMemberExpr>(funcExpr)) - { - outInfo->funcDeclRef = staticMemberFuncExpr->declRef; - return true; - } - else if (auto varExpr = as<VarExpr>(funcExpr)) - { - outInfo->funcDeclRef = varExpr->declRef; - return true; - } - else - { - // Seems to be a case of declaration-reference we don't know about. - SLANG_UNEXPECTED("unknown declaration reference kind"); - //return false; - } - } - - /// Return `expr` with any outer casts to interface types stripped away - Expr* maybeIgnoreCastToInterface(Expr* expr) - { - auto e = expr; - while( auto castExpr = as<CastToSuperTypeExpr>(e) ) - { - if(auto declRefType = as<DeclRefType>(e->type)) - { - if(declRefType->getDeclRef().as<InterfaceDecl>()) - { - e = castExpr->valueArg; - continue; - } - } - else if( auto andType = as<AndType>(e->type) ) - { - // TODO: We might eventually need to tell the difference - // between conjunctions of interfaces and conjunctions - // that might include non-interface types. - // - // For now we assume that any case to a conjunction - // is effectively a cast to an interface type. - // - e = castExpr->valueArg; - continue; - } - break; - } - return e; - } - LoweredValInfo visitSelectExpr(SelectExpr* expr) { // A vector typed `select` expr will turn into a normal `select` op. @@ -4126,158 +4371,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> LoweredValInfo visitInvokeExpr(InvokeExpr* expr) { - return visitInvokeExprImpl(expr, TryClauseEnvironment()); - } - - LoweredValInfo visitInvokeExprImpl(InvokeExpr* expr, const TryClauseEnvironment& tryEnv) - { - auto type = lowerType(context, expr->type); - - // We are going to look at the syntactic form of - // the "function" expression, so that we can avoid - // a lot of complexity that would come from lowering - // it as a general expression first, and then trying - // to apply it. For example, given `obj.f(a,b)` we - // will try to detect that we are trying to compute - // something like `ObjType::f(obj, a, b)` (in pseudo-code), - // rather than trying to construct a meaningful - // intermediate value for `obj.f` first. - // - // Note that this doe not preclude having support - // for directly generating code from `obj.f` - it - // just may be that such usage is more complicated. - - // Along the way, we may end up collecting additional - // arguments that will be part of the call. - List<IRInst*> irArgs; - - // We will also collect "fixup" actions that need - // to be performed after the call, in order to - // copy the final values for `out` parameters - // back to their arguments. - List<OutArgumentFixup> argFixups; - - auto funcExpr = expr->functionExpr; - ResolvedCallInfo resolvedInfo; - if (tryResolveDeclRefForCall(funcExpr, &resolvedInfo)) - { - // In this case we know exactly what declaration we - // are going to call, and so we can resolve things - // appropriately. - auto funcDeclRef = resolvedInfo.funcDeclRef; - auto baseExpr = resolvedInfo.baseExpr; - - // If the thing being invoked is a subscript operation, - // then we need to handle multiple extra details - // that don't arise for other kinds of calls. - // - // TODO: subscript operations probably deserve to - // be handled on their own path for this reason... - // - if (auto subscriptDeclRef = funcDeclRef.template as<SubscriptDecl>()) - { - // A reference to a subscript declaration is a special case, - // because it is not possible to call a subscript directly; - // we must call one of its accessors. - // - auto loweredBase = lowerSubExpr(baseExpr); - addDirectCallArgs(expr, funcDeclRef, &irArgs, &argFixups); - auto result = lowerStorageReference(context, type, subscriptDeclRef, loweredBase, irArgs.getCount(), irArgs.getBuffer()); - - // TODO: Applying the fixups for arguments to the subscript at this point - // won't technically be correct, since the call to the subscript may - // not have occured at this point. - // - // It seems like we need to either: - // - // * Capture the arguments to the subscript as `LoweredValInfo` instead of `IRInst*` - // so that we can deal with everything related to fixups around the actual call - // site. - // - // OR - // - // * Handle everything to do with "fixups" differently, by treating them as deferred - // actions that gert queued up on the context itself and then flushed at certain - // well-defined points, so that we don't have to be as careful around them. - // - // OR - // - // * Switch to a more "destination-driven" approach to code generation, where we - // can determine on entry to the lowering of a sub-expression whether it will be - // used for read, write, or read/write, and resolve things like the choice of - // accessor at that point instead. - // - applyOutArgumentFixups(context, argFixups); - return result; - } - - // First comes the `this` argument if we are calling - // a member function: - if (baseExpr) - { - // The base expression might be an "upcast" to a base interface, in - // which case we don't want to emit the result of the cast, but instead - // the source. - // - baseExpr = maybeIgnoreCastToInterface(baseExpr); - - auto thisType = getThisParamTypeForCallable(context, funcDeclRef); - auto irThisType = lowerType(context, thisType); - addCallArgsForParam( - context, - irThisType, - getThisParamDirection(funcDeclRef.getDecl(), kParameterDirection_In), - baseExpr, - &irArgs, - &argFixups); - } - - // Then we have the "direct" arguments to the call. - // These may include `out` and `inout` arguments that - // require "fixup" work on the other side. - // - FuncDeclBaseTypeInfo funcTypeInfo; - _lowerFuncDeclBaseTypeInfo(context, funcDeclRef.template as<FunctionDeclBase>(), funcTypeInfo); - - auto funcType = funcTypeInfo.type; - addDirectCallArgs(expr, funcDeclRef, &irArgs, &argFixups); - auto result = emitCallToDeclRef( - context, - type, - funcDeclRef, - funcType, - irArgs, - tryEnv); - applyOutArgumentFixups(context, argFixups); - return result; - } - else if(auto funcType = as<FuncType>(expr->functionExpr->type)) - { - auto funcVal = lowerRValueExpr(context, expr->functionExpr); - addDirectCallArgs(expr, funcType, &irArgs, &argFixups); - - auto result = emitCallToVal(context, type, funcVal, irArgs.getCount(), irArgs.getBuffer(), tryEnv); - - applyOutArgumentFixups(context, argFixups); - return result; - } - - - // TODO: In this case we should be emitting code for the callee as - // an ordinary expression, then emitting the arguments according - // to the type information on the callee (e.g., which parameters - // are `out` or `inout`, and then finally emitting the `call` - // instruction. - // - // We don't currently have the case of emitting arguments according - // to function type info (instead of declaration info), and really - // this case can't occur unless we start adding first-class functions - // to the source language. - // - // For now we just bail out with an error. - // - SLANG_UNEXPECTED("could not resolve target declaration for call"); - UNREACHABLE_RETURN(LoweredValInfo()); + return sharedLoweringContext.visitInvokeExprImpl(expr, LoweredValInfo(), TryClauseEnvironment()); } /// Emit code for a `try` invoke. @@ -4287,7 +4381,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> assert(invokeExpr); TryClauseEnvironment tryEnv; tryEnv.clauseType = expr->tryClauseType; - return visitInvokeExprImpl(invokeExpr, tryEnv); + return sharedLoweringContext.visitInvokeExprImpl(invokeExpr, LoweredValInfo(), tryEnv); } /// Emit code to cast `value` to a concrete `superType` (e.g., a `struct`). @@ -4540,8 +4634,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> // based on the resulting values. // auto leftVal = lowerLValueExpr(context, expr->left); - auto rightVal = lowerRValueExpr(context, expr->right); - assign(context, leftVal, rightVal); + assignExpr(context, leftVal, expr->right, expr->loc); // The result value of the assignment expression is // the value of the left-hand side (and it is expected @@ -4780,7 +4873,7 @@ struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVis } }; -struct RValueExprLoweringVisitor : ExprLoweringVisitorBase<RValueExprLoweringVisitor> +struct RValueExprLoweringVisitor : public ExprLoweringVisitorBase<RValueExprLoweringVisitor> { static bool _isLValueContext() { return false; } @@ -4868,6 +4961,55 @@ struct RValueExprLoweringVisitor : ExprLoweringVisitorBase<RValueExprLoweringVis } }; +// ExprLoweringVisitor that fuses the destination assignment. +// +struct DestinationDrivenRValueExprLoweringVisitor + : ExprVisitor<DestinationDrivenRValueExprLoweringVisitor> +{ + ExprLoweringContext<DestinationDrivenRValueExprLoweringVisitor> sharedLoweringContext; + LoweredValInfo destination; + + IRGenContext*& context; + DestinationDrivenRValueExprLoweringVisitor() + : context(sharedLoweringContext.context) + {} + + static bool _isLValueContext() { return false; } + + // The default case is lower the rvalue expr independently and then assign to destination. + void visitExpr(Expr* expr) + { + auto rValue = lowerRValueExpr(context, expr); + assign(context, destination, rValue); + } + + void visitInvokeExpr(InvokeExpr* expr) + { + LoweredValInfo resultRVal; + { + IRBuilderSourceLocRAII sourceLocInfo(context->irBuilder, expr->loc); + resultRVal = sharedLoweringContext.visitInvokeExprImpl(expr, destination, TryClauseEnvironment{}); + } + if (resultRVal.flavor != LoweredValInfo::Flavor::None) + { + // If we weren't able to fuse the destination write during lowering rvalue, + // we should insert the assign operation now. + assign(context, destination, resultRVal); + } + } + + /// Emit code for a `try` invoke. + LoweredValInfo visitTryExpr(TryExpr* expr) + { + auto invokeExpr = as<InvokeExpr>(expr->base); + assert(invokeExpr); + TryClauseEnvironment tryEnv; + tryEnv.clauseType = expr->tryClauseType; + return sharedLoweringContext.visitInvokeExprImpl(invokeExpr, destination, tryEnv); + } + +}; + LoweredValInfo lowerLValueExpr( IRGenContext* context, Expr* expr) @@ -4892,6 +5034,17 @@ LoweredValInfo lowerRValueExpr( return info; } +void lowerRValueExprWithDestination( + IRGenContext* context, + LoweredValInfo destination, + Expr* expr) +{ + DestinationDrivenRValueExprLoweringVisitor visitor; + visitor.context = context; + visitor.destination = destination; + visitor.dispatch(expr); +} + struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> { IRGenContext* context; @@ -5464,6 +5617,14 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // if( auto expr = stmt->expression ) { + if (context->returnDestination.flavor != LoweredValInfo::Flavor::None) + { + // If this function should return via a __ref parameter, do that and return void. + lowerRValueExprWithDestination(context, context->returnDestination, expr); + getBuilder()->emitReturn(); + return; + } + // If the AST `return` statement had an expression, then we // need to lower it to the IR at this point, both to // compute its value and (in case we are returning a @@ -6125,6 +6286,30 @@ IRInst* getAddress( return nullptr; } +void assignExpr( + IRGenContext* context, + const LoweredValInfo& inLeft, + Expr* rightExpr, + SourceLoc assignmentLoc) +{ + auto left = tryGetAddress(context, inLeft, TryGetAddressMode::Default); + IRBuilderSourceLocRAII locRAII(context->irBuilder, assignmentLoc); + switch (left.flavor) + { + case LoweredValInfo::Flavor::Ptr: + { + lowerRValueExprWithDestination(context, left, rightExpr); + } + break; + default: + { + auto right = lowerRValueExpr(context, rightExpr); + assign(context, inLeft, right); + } + break; + } +} + void assign( IRGenContext* context, LoweredValInfo const& inLeft, @@ -7228,6 +7413,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> subContextStorage.thisType = outerContext->thisType; subContextStorage.thisTypeWitness = outerContext->thisTypeWitness; + + subContextStorage.returnDestination = LoweredValInfo(); } IRBuilder* getBuilder() { return &subBuilderStorage; } @@ -7381,9 +7568,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> if( auto initExpr = decl->initExpr ) { - auto initVal = lowerRValueExpr(context, initExpr); - - assign(context, varVal, initVal); + assignExpr(context, varVal, initExpr, decl->loc); } context->setGlobalValue(decl, varVal); @@ -7395,7 +7580,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { return Slang::getInterfaceRequirementKey(context, requirementDecl); } - + LoweredValInfo visitAssocTypeDecl(AssocTypeDecl* decl) { SLANG_ASSERT(decl->parentDecl != nullptr); @@ -8700,6 +8885,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> paramVal = LoweredValInfo::ptr(irParam); + if (paramInfo.isReturnDestination) + subContext->returnDestination = paramVal; + // TODO: We might want to copy the pointed-to value into // a temporary at the start of the function, and then copy // back out at the end, so that we don't have to worry @@ -8815,14 +9003,19 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> auto constructorDecl = as<ConstructorDecl>(decl); if (constructorDecl) { - auto thisVar = subContext->irBuilder->emitVar(irResultType); - subContext->thisVal = LoweredValInfo::ptr(thisVar); - - // For class-typed objects, we need to allocate it from heap. - if (isClassType(irResultType)) + if (subContext->returnDestination.flavor != LoweredValInfo::Flavor::None) + subContext->thisVal = subContext->returnDestination; + else { - auto allocatedObj = subContext->irBuilder->emitAllocObj(irResultType); - subContext->irBuilder->emitStore(thisVar, allocatedObj); + auto thisVar = subContext->irBuilder->emitVar(irResultType); + subContext->thisVal = LoweredValInfo::ptr(thisVar); + + // For class-typed objects, we need to allocate it from heap. + if (isClassType(irResultType)) + { + auto allocatedObj = subContext->irBuilder->emitAllocObj(irResultType); + subContext->irBuilder->emitStore(thisVar, allocatedObj); + } } } @@ -8846,8 +9039,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // path in an initializer/constructor attempts // to do an early `return;`. // - subContext->irBuilder->emitReturn( - getSimpleVal(subContext, subContext->thisVal)); + if (subContext->returnDestination.flavor != LoweredValInfo::Flavor::None) + subContext->irBuilder->emitReturn(); + else + { + subContext->irBuilder->emitReturn( + getSimpleVal(subContext, subContext->thisVal)); + } } else if (as<IRVoidType>(irResultType)) { diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index b270ba713..c3eba8c58 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -5299,6 +5299,13 @@ namespace Slang return expr; } + static NodeBase* parseReturnValExpr(Parser* parser, void* /*userData*/) + { + ReturnValExpr* expr = parser->astBuilder->create<ReturnValExpr>(); + expr->scope = parser->currentScope; + return expr; + } + static Expr* parseBoolLitExpr(Parser* parser, bool value) { BoolLiteralExpr* expr = parser->astBuilder->create<BoolLiteralExpr>(); @@ -7314,6 +7321,7 @@ namespace Slang _makeParseExpr("this", parseThisExpr), _makeParseExpr("true", parseTrueExpr), _makeParseExpr("false", parseFalseExpr), + _makeParseExpr("__return_val", parseReturnValExpr), _makeParseExpr("nullptr", parseNullPtrExpr), _makeParseExpr("none", parseNoneExpr), _makeParseExpr("try", parseTryExpr), diff --git a/tests/bindings/hlsl-to-vulkan-shift-rw-structured.hlsl b/tests/bindings/hlsl-to-vulkan-shift-rw-structured.hlsl index 2fcbdf77c..17d2036e7 100644 --- a/tests/bindings/hlsl-to-vulkan-shift-rw-structured.hlsl +++ b/tests/bindings/hlsl-to-vulkan-shift-rw-structured.hlsl @@ -1,10 +1,10 @@ //TEST:SIMPLE(filecheck=CHECK):-target glsl -profile glsl_450 -entry MainCs -stage compute -fvk-b-shift 0 0 -fvk-s-shift 14 0 -fvk-t-shift 30 0 -fvk-u-shift 158 0 -// CHECK: layout(std430, binding = 159) buffer -// CHECK: } g_ByteBuffer +// CHECK-DAG: layout(std430, binding = 159) buffer +// CHECK-DAG: } g_ByteBuffer -// CHECK: layout(std430, binding = 158) buffer +// CHECK-DAG: layout(std430, binding = 158) buffer RWStructuredBuffer<uint> g_OutputCullBits; RWByteAddressBuffer g_ByteBuffer; diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-assign.slang b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-assign.slang index 7f5be6243..49fef7a4c 100644 --- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-assign.slang +++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-assign.slang @@ -1,7 +1,7 @@ // hit-object-assign.slang //TEST:SIMPLE: -target dxil -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -DNV_SHADER_EXTN_SLOT=u0 -//TEST:SIMPLE: -target glsl -entry rayGenerationMain -stage raygeneration -profile glsl_460+GL_EXT_ray_tracing -O0 -line-directive-mode none +//TEST:SIMPLE(filecheck=SPIRV): -target spirv -entry rayGenerationMain -stage raygeneration -profile glsl_460+GL_EXT_ray_tracing -O0 -line-directive-mode none //DISABLE_TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -output-using-type -profile sm_6_5 -nvapi-slot u0 //DISABLE_TEST(compute):COMPARE_COMPUTE:-d3d12 -output-using-type -use-dxil -profile sm_6_5 -render-feature ray-query @@ -10,6 +10,9 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<uint> outputBuffer; +// SPIRV: OpHitObjectRecordMissNV +// SPIRV: OpHitObjectIsMissNV + void rayGenerationMain() { int2 launchID = int2(DispatchRaysIndex().xy); diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-hit.slang b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-hit.slang index e38b29446..8fea9cf67 100644 --- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-hit.slang +++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-hit.slang @@ -1,7 +1,7 @@ // hit-object-make-hit.slang //TEST:SIMPLE: -target dxil -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -DNV_SHADER_EXTN_SLOT=u0 -//TEST:SIMPLE: -target glsl -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -line-directive-mode none +//TEST:SIMPLE(filecheck=SPIRV): -target spirv -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -line-directive-mode none //DISABLE_TEST(compute):COMPARE_COMPUTE:-d3d12 -output-using-type -use-dxil -profile sm_6_5 -render-feature ray-query //DISABLE_TEST(compute):COMPARE_COMPUTE:-vk -output-using-type -render-feature ray-query @@ -18,8 +18,16 @@ struct SomeValues float b; }; + uint calcValue(HitObject hit) { + // SPIRV-DAG: OpHitObjectIsHitNV + // SPIRV: OpHitObjectGetInstanceCustomIndexNV + // SPIRV: OpHitObjectGetInstanceIdNV + // SPIRV: OpHitObjectGetGeometryIndexNV + // SPIRV: OpHitObjectGetPrimitiveIndexNV + // SPIRV: OpHitObjectGetHitKindNV + // SPIRV: OpHitObjectIsMissNV uint r = 0; if (hit.IsHit()) @@ -72,6 +80,7 @@ void rayGenerationMain() uint r = 0; { + // SPIRV-DAG: OpHitObjectRecordHitNV HitObject hit = HitObject::MakeHit(0, scene, idx, idx * 2, idx * 3, hitKind, ray, someValues); r = calcValue(hit); diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-miss.slang b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-miss.slang index 421063987..9aea89573 100644 --- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-miss.slang +++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-miss.slang @@ -1,7 +1,7 @@ // hit-object-make-miss.slang -//TEST:SIMPLE: -target dxil -entry computeMain -stage compute -profile sm_6_5 -DNV_SHADER_EXTN_SLOT=u0 -//TEST:SIMPLE: -target glsl -entry computeMain -stage compute -profile sm_6_5 -line-directive-mode none +//TEST:SIMPLE: -target dxil -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -DNV_SHADER_EXTN_SLOT=u0 +//TEST:SIMPLE(filecheck=SPIRV): -target spirv -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -line-directive-mode none //DISABLE_TEST:COMPARE_COMPUTE_EX:-slang -compute -dx12 -output-using-type -profile sm_6_5 -nvapi-slot u0 //DISABLE_TEST(compute):COMPARE_COMPUTE:-d3d12 -output-using-type -use-dxil -profile sm_6_5 -render-feature ray-query @@ -10,19 +10,19 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<uint> outputBuffer; -[numthreads(4, 1, 1)] -void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) -{ - int idx = int(dispatchThreadID.x); +void rayGenerationMain() +{ + int idx = DispatchRaysIndex().x; RayDesc ray; ray.Origin = float3(idx, 0, 0); ray.TMin = 0.01f; ray.Direction = float3(0, 1, 0); ray.TMax = 1e4f; - + // SPIRV: OpHitObjectRecordMissNV HitObject hit = HitObject::MakeMiss(idx, ray); - + + // SPIRV: OpHitObjectIsMissNV int r = int(hit.IsMiss()); outputBuffer[idx] = r; diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-nop.slang b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-nop.slang index b1d72c47e..e8c88e1ad 100644 --- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-nop.slang +++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-make-nop.slang @@ -1,7 +1,7 @@ // hit-object-make-nop.slang //TEST:SIMPLE: -target dxil -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -DNV_SHADER_EXTN_SLOT=u0 -//TEST:SIMPLE: -target glsl -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -line-directive-mode none +//TEST:SIMPLE(filecheck=SPIRV): -target spirv -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -line-directive-mode none //DISABLE_TEST(compute):COMPARE_COMPUTE:-d3d12 -output-using-type -use-dxil -profile sm_6_5 -render-feature ray-query //DISABLE_TEST(compute):COMPARE_COMPUTE:-vk -output-using-type -render-feature ray-query @@ -16,6 +16,7 @@ void rayGenerationMain() int idx = launchID.x; + // SPIRV: OpHitObjectRecordEmptyNV HitObject hit = HitObject::MakeNop(); outputBuffer[idx] = uint(hit.IsNop()); diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-output.slang b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-output.slang index e8afaf217..e06e63693 100644 --- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-output.slang +++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-output.slang @@ -4,7 +4,7 @@ // as function results (including `out` parameters) //TEST:SIMPLE: -target dxil -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -DNV_SHADER_EXTN_SLOT=u0 -//TEST:SIMPLE: -target glsl -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -line-directive-mode none +//TEST:SIMPLE(filecheck=SPIRV): -target spirv -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -line-directive-mode none //DISABLE_TEST(compute):COMPARE_COMPUTE:-d3d12 -output-using-type -use-dxil -profile sm_6_5 -render-feature ray-query //DISABLE_TEST(compute):COMPARE_COMPUTE:-vk -output-using-type -render-feature ray-query @@ -57,7 +57,8 @@ HitObject myTraceRay(uint idx) uint multiplierForGeometryContributionToHitGroupIndex = 4; uint missShaderIndex = 0; - HitObject hit = HitObject::TraceRay(scene, + // SPIRV-DAG: OpHitObjectTraceRayNV + return HitObject::TraceRay(scene, rayFlags, instanceInclusionMask, rayContributionToHitGroupIndex, @@ -65,8 +66,6 @@ HitObject myTraceRay(uint idx) missShaderIndex, ray, payload); - - return hit; } void copyHitObjectHandle( @@ -94,16 +93,20 @@ void rayGenerationMain() accumulate(r, hit); +#if 0 // cannot support this right now HitObject hit2; copyHitObjectHandle(hit2, hit); accumulate(r, hit2); +#else + accumulate(r, hit); + +#endif - HitObject hitBackup = hit; myMakeMiss(idx, hit); accumulate(r, hit); - accumulate(r, hitBackup); + accumulate(r, hit); outputBuffer[idx] = r; } diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-reorder-thread.slang b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-reorder-thread.slang index ed83b8d47..8498b9304 100644 --- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-reorder-thread.slang +++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-reorder-thread.slang @@ -1,7 +1,7 @@ // hit-object-reorder-thread.slang //TEST:SIMPLE: -target dxil -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -DNV_SHADER_EXTN_SLOT=u0 -//TEST:SIMPLE: -target glsl -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -line-directive-mode none +//TEST:SIMPLE(filecheck=SPIRV): -target spirv -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -line-directive-mode none //DISABLE_TEST(compute):COMPARE_COMPUTE:-d3d12 -output-using-type -use-dxil -profile sm_6_5 -render-feature ray-query //DISABLE_TEST(compute):COMPARE_COMPUTE:-vk -output-using-type -render-feature ray-query @@ -62,6 +62,7 @@ void rayGenerationMain() uint multiplierForGeometryContributionToHitGroupIndex = 4; uint missShaderIndex = 0; + // SPIRV: OpHitObjectTraceRayNV HitObject hit = HitObject::TraceRay(scene, rayFlags, instanceInclusionMask, @@ -70,42 +71,47 @@ void rayGenerationMain() missShaderIndex, ray, someValues); - + uint r = calcValue(hit); - + // SPIRV: OpReorderThreadWithHitObjectNV ReorderThread( hit ); // Change the payload SomeValues otherValues = { idx * -1, idx * 4.0f }; - + // Now Invoke to cast another ray, with the new payload + // SPIRV: OpHitObjectExecuteShaderNV HitObject::Invoke( scene, hit, otherValues ); r += calcValue(hit); // !!! TODO(JS) !!! // NOTE! If I enable this I end up with a recursive failure in AST traversal, if - // otherValues is redefined. - - // Reorder + // otherValues is redefined. + + // Reorder + // SPIRV: OpReorderThreadWithHitObjectNV ReorderThread(hit, uint(idx & 3), 2); // Change the payload otherValues = { idx * -2, idx * 8.0f }; - + // Now Invoke to cast another ray, with the new payload + // SPIRV: OpHitObjectExecuteShaderNV HitObject::Invoke( scene, hit, otherValues ); r += calcValue(hit); - // Reorder + // Reorder + // SPIRV: OpReorderThreadWithHintNV ReorderThread(uint(idx & 1), 1); // Change the payload otherValues = { idx * -4, idx * 16.0f }; - + // Now Invoke to cast another ray, with the new payload + // SPIRV: OpHitObjectExecuteShaderNV HitObject::Invoke( scene, hit, otherValues ); r += calcValue(hit); diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-motion-ray.slang b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-motion-ray.slang index f57ecf02a..4e8d52e10 100644 --- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-motion-ray.slang +++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-motion-ray.slang @@ -2,7 +2,7 @@ // Motion rays not supported on HLSL impl currently //DISABLE_TEST:SIMPLE: -target dxil -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -DNV_SHADER_EXTN_SLOT=u0 -//TEST:SIMPLE: -target glsl -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -line-directive-mode none +//TEST:SIMPLE(filecheck=SPIRV): -target spirv -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -line-directive-mode none //DISABLE_TEST(compute):COMPARE_COMPUTE:-d3d12 -output-using-type -use-dxil -profile sm_6_6 -render-feature ray-query //DISABLE_TEST(compute):COMPARE_COMPUTE:-vk -output-using-type -render-feature ray-query @@ -64,7 +64,7 @@ void rayGenerationMain() uint rayContributionToHitGroupIndex = 0; uint multiplierForGeometryContributionToHitGroupIndex = 4; uint missShaderIndex = 0; - + // SPIRV: OpHitObjectTraceRayMotionNV HitObject hit = HitObject::TraceMotionRay(scene, rayFlags, instanceInclusionMask, diff --git a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang index 63ae4c957..c1a29d647 100644 --- a/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang +++ b/tests/hlsl-intrinsic/shader-execution-reordering/hit-object-trace-ray.slang @@ -1,7 +1,7 @@ // hit-object-trace-ray.slang //TEST:SIMPLE: -target dxil -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -DNV_SHADER_EXTN_SLOT=u0 -//TEST:SIMPLE: -target glsl -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -line-directive-mode none +//TEST:SIMPLE(filecheck=SPIRV): -target spirv -entry rayGenerationMain -stage raygeneration -profile sm_6_5 -line-directive-mode none //DISABLE_TEST(compute):COMPARE_COMPUTE:-d3d12 -output-using-type -use-dxil -profile sm_6_6 -render-feature ray-query //DISABLE_TEST(compute):COMPARE_COMPUTE:-vk -output-using-type -render-feature ray-query @@ -61,7 +61,7 @@ void rayGenerationMain() uint rayContributionToHitGroupIndex = 0; uint multiplierForGeometryContributionToHitGroupIndex = 4; uint missShaderIndex = 0; - + // SPIRV: OpHitObjectTraceRayNV HitObject hit = HitObject::TraceRay(scene, rayFlags, instanceInclusionMask, diff --git a/tests/language-feature/non-copyable-return.slang b/tests/language-feature/non-copyable-return.slang new file mode 100644 index 000000000..20330c5f9 --- /dev/null +++ b/tests/language-feature/non-copyable-return.slang @@ -0,0 +1,37 @@ +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type +//TEST(compute):SIMPLE(filecheck=GLSL): -stage compute -entry computeMain -target glsl + +// Note: spirv_by_reference is only supported for passing opaque types, so this test won't produce +// expected result on vulkan. +//DISABLED_TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type + +[__NonCopyableType] +struct MyType +{ + float x; + __init() { x = 1.0; } +} + +MyType myFunc1(float y) +{ + __return_val = MyType(); + __return_val.x += y; +} + +MyType myFunc0(float x) +{ + return myFunc1(x + 1.0); +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(int3 dispatchThreadID: SV_DispatchThreadID) +{ + let f = myFunc0(2.0); + // CHECK: 4.0 + // GLSL: void myFunc1_0(float y{{.*}}, spirv_by_reference MyType_0 {{.*}}) + // GLSL: void myFunc0_0(float x{{.*}}, spirv_by_reference MyType_0 {{.*}}) + outputBuffer[0] = f.x; +} diff --git a/tests/pipeline/ray-tracing/trace-ray-inline.slang.hlsl b/tests/pipeline/ray-tracing/trace-ray-inline.slang.hlsl index bf10cc2e1..4ed4c9966 100644 --- a/tests/pipeline/ray-tracing/trace-ray-inline.slang.hlsl +++ b/tests/pipeline/ray-tracing/trace-ray-inline.slang.hlsl @@ -1,4 +1,8 @@ #pragma pack_matrix(column_major) +#ifdef SLANG_HLSL_ENABLE_NVAPI +#include "nvHLSLExtns.h" +#endif +#pragma warning(disable: 3557) struct SLANG_ParameterGroup_C_0 { @@ -15,7 +19,6 @@ cbuffer C_0 : register(b0) { SLANG_ParameterGroup_C_0 C_0; } - RaytracingAccelerationStructure myAccelerationStructure_0 : register(t0); RWStructuredBuffer<int > resultBuffer_0 : register(u0); @@ -67,53 +70,41 @@ void myMiss_0(inout MyRayPayload_0 payload_4) void main(uint3 tid_0 : SV_DISPATCHTHREADID) { uint index_0 = tid_0.x; - RayQuery<int(512) > query_0; - MyRayPayload_0 payload_5; payload_5.value_1 = int(-1); - RayDesc ray_0 = { C_0.origin_0, C_0.tMin_0, C_0.direction_0, C_0.tMax_0 }; - + RayQuery<512U > query_0; query_0.TraceRayInline(myAccelerationStructure_0, C_0.rayFlags_0, C_0.instanceMask_0, ray_0); - MyProceduralHitAttrs_0 committedProceduralAttrs_0; - for(;;) { - bool _S1 = query_0.Proceed(); - if(!_S1) { - break; } - + uint _S2 = query_0.CandidateType(); MyProceduralHitAttrs_0 committedProceduralAttrs_1; - switch(query_0.CandidateType()) + switch(_S2) { case 1U: { MyProceduralHitAttrs_0 candidateProceduralAttrs_0; candidateProceduralAttrs_0.value_0 = int(0); - float tHit_1 = 0.0; - bool _S2 = myProceduralIntersection_0(tHit_1, candidateProceduralAttrs_0); - if(_S2) + bool _S3 = myProceduralIntersection_0(tHit_1, candidateProceduralAttrs_0); + if(_S3) { - bool _S3 = myProceduralAnyHit_0(payload_5); - if(_S3) + bool _S4 = myProceduralAnyHit_0(payload_5); + if(_S4) { query_0.CommitProceduralPrimitiveHit(tHit_1); - MyProceduralHitAttrs_0 _S4 = candidateProceduralAttrs_0; + MyProceduralHitAttrs_0 _S5 = candidateProceduralAttrs_0; if(C_0.shouldStopAtFirstHit_0 != 0U) { query_0.Abort(); } - else - {} - - committedProceduralAttrs_1 = _S4; + committedProceduralAttrs_1 = _S5; } else { @@ -128,20 +119,15 @@ void main(uint3 tid_0 : SV_DISPATCHTHREADID) } case 0U: { - bool _S5 = myTriangleAnyHit_0(payload_5); - if(_S5) + bool _S6 = myTriangleAnyHit_0(payload_5); + if(_S6) { query_0.CommitNonOpaqueTriangleHit(); if(C_0.shouldStopAtFirstHit_0 != 0U) { query_0.Abort(); } - else - {} } - else - {} - committedProceduralAttrs_1 = committedProceduralAttrs_0; break; } @@ -151,11 +137,10 @@ void main(uint3 tid_0 : SV_DISPATCHTHREADID) break; } } - committedProceduralAttrs_0 = committedProceduralAttrs_1; } - - switch(query_0.CommittedStatus()) + uint _S7 = query_0.CommittedStatus(); + switch(_S7) { case 1U: { |
