diff options
| author | Yong He <yonghe@outlook.com> | 2017-11-07 19:09:40 -0500 |
|---|---|---|
| committer | Tim Foley <tfoleyNV@users.noreply.github.com> | 2017-11-07 16:09:40 -0800 |
| commit | 6e591ada0eb652c320bba4bd8a46cd579946df01 (patch) | |
| tree | 768229fb26204b6b0a89201d9b14c32e9203c098 | |
| parent | 939688e963fde7a0485f210ef2674c27692021a4 (diff) | |
Support generic interface methods (#251)
* improve diagnostic messages and prevent fatal errors from crashing the compiler.
* fix top level exception catching.
* spelling fix
* change wording of invalidSwizzleExpr diagnostic
* add speculative GenericsApp expr parsing
* add new test case of cascading generics call.
* Fixing bugs in compiling cascaded generic function calls.
Add implementation of DeclaredSubTypeWitness::SubstituteImpl()
This is not needed by the type checker, but needed by IR specialization. When input source contains cascading generic function call, the arguments to `specialize` instruction is currently represented as a substitution. The arg values of this subsittution can be a `DeclaredSubTypeWitness` when a generic function uses one of its generic parameter to specialize another generic function. When the top level generics function is being specialized, this substitution argument, which is a `DeclaredSubTypeWitness`, needs to be substituted with the witness that used to specialize the top level function in the specialized specialize instruction as well.
* add a test case for cascading generic function call.
* parser bug fix
* fixes #255
* add test case for issue #255
* Generate missing `specialize` instruction when calling a generic method from an interface constraint.
When calling a generic method via an interface, we should be generating the following ir:
...
f = lookup_interface_method(...)
f_s = specailize(f, declRef)
...
This commit fixes this `emitFuncRef` function to emit the needed `specialize` instruction.
* fixes #260
This fix follows the second apporach in the disucssion. It generated mangled name for specialized functions by appending new substitution type names to the original mangled name.
* Disabling removing and re-inserting specailized functions in getSpecalizeFunc()
I am not sure why it is needed, it seems HLSL and GLSL backends are generating forward declarations anyways, so the order of functions in IRModule shouldn't matter.
* cleanup and complete test cases.
* fix warnings
24 files changed, 524 insertions, 65 deletions
diff --git a/source/core/exception.h b/source/core/exception.h index aedb62add..fc7aa48e2 100644 --- a/source/core/exception.h +++ b/source/core/exception.h @@ -121,6 +121,17 @@ namespace Slang { } }; + + class AbortCompilationException : public Exception + { + public: + AbortCompilationException() + {} + AbortCompilationException(const String & message) + : Exception(message) + { + } + }; } #endif
\ No newline at end of file diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 6bb7c232f..320b22bdb 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -1588,6 +1588,44 @@ namespace Slang return true; } + bool doesGenericSignatureMatchRequirement( + GenericDecl * genDecl, + DeclRef<GenericDecl> requirementGenDecl) + { + // TODO: genDecl should be a DeclRef to capture the environment generic variables needed to get + // a concrete type for a generic constraint super type (e.g. when this member belongs to a generic type) + + if (genDecl->Members.Count() != requirementGenDecl.getDecl()->Members.Count()) + return false; + for (UInt i = 0; i < genDecl->Members.Count(); i++) + { + auto genMbr = genDecl->Members[i]; + auto requiredGenMbr = genDecl->Members[i]; + if (auto genTypeMbr = genMbr.As<GenericTypeParamDecl>()) + { + if (auto requiredGenTypeMbr = requiredGenMbr.As<GenericTypeParamDecl>()) + { + } + else + return false; + } + else if (auto genTypeConstraintMbr = genMbr.As<GenericTypeConstraintDecl>()) + { + if (auto requiredTypeConstraintMbr = requiredGenMbr.As<GenericTypeConstraintDecl>()) + { + if (!genTypeConstraintMbr->sup->Equals(requiredTypeConstraintMbr->sup)) + { + return false; + } + } + else + return false; + } + } + return doesMemberSatisfyRequirement(genDecl->inner.Ptr(), + DeclRef<Decl>(requirementGenDecl.getDecl()->inner.Ptr(), requirementGenDecl.substitutions)); + } + // Does the given `memberDecl` work as an implementation // to satisfy the requirement `requiredMemberDeclRef` // from an interface? @@ -1634,6 +1672,13 @@ namespace Slang requiredInitDecl); } } + else if (auto genDecl = dynamic_cast<GenericDecl*>(memberDecl)) + { + if (auto requiredGenDeclRef = requiredMemberDeclRef.As<GenericDecl>()) + { + return doesGenericSignatureMatchRequirement(genDecl, requiredGenDeclRef); + } + } else if (auto subStructTypeDecl = dynamic_cast<AggTypeDecl*>(memberDecl)) { // this is a sub type (e.g. nested struct declaration) in an aggregate type @@ -4462,9 +4507,13 @@ namespace Slang DeclRef<Decl> innerDeclRef(GetInner(baseGenericRef), subst); + RefPtr<Expr> base; + if (auto mbrExpr = baseExpr.As<MemberExpr>()) + base = mbrExpr->BaseExpression; + return ConstructDeclRefExpr( innerDeclRef, - nullptr, + base, originalExpr->loc); } @@ -5738,7 +5787,6 @@ namespace Slang // declarations with the same name, so this becomes a specialized case of // overload resolution. - // Start by checking the base expression and arguments. auto& baseExpr = genericAppExpr->FunctionExpr; baseExpr = CheckTerm(baseExpr); @@ -6095,9 +6143,9 @@ namespace Slang case 'w': case 'a': elementIndex = 3; break; default: // An invalid character in the swizzle is an error - if (!isRewriteMode()) + if (!isRewriteMode() && !anyError) { - getSink()->diagnose(swizExpr, Diagnostics::unimplemented, "invalid component name for swizzle"); + getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->ToString()); } anyError = true; continue; @@ -6109,9 +6157,9 @@ namespace Slang // Make sure the index is in range for the source type if (elementIndex >= limitElement) { - if (!isRewriteMode()) + if (!isRewriteMode() && !anyError) { - getSink()->diagnose(swizExpr, Diagnostics::unimplemented, "swizzle component out of range for type"); + getSink()->diagnose(swizExpr, Diagnostics::invalidSwizzleExpr, swizzleText, baseElementType->ToString()); } anyError = true; continue; diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h index 10b2dbd1e..7f27e43e8 100644 --- a/source/slang/diagnostic-defs.h +++ b/source/slang/diagnostic-defs.h @@ -191,6 +191,7 @@ DIAGNOSTIC(30035, Error, componentOverloadTypeMismatch, "'$0': type of overloade DIAGNOSTIC(30041, Error, bitOperationNonIntegral, "bit operation: operand must be integral type.") DIAGNOSTIC(30047, Error, argumentExpectedLValue, "argument passed to parameter '$0' must be l-value.") DIAGNOSTIC(30051, Error, invalidValueForArgument, "invalid value for argument '$0'") +DIAGNOSTIC(30052, Error, invalidSwizzleExpr, "invalid swizzle pattern '$0' on type '$1'") DIAGNOSTIC(33070, Error, expectedFunction, "expression preceding parenthesis of apparent call must have function type.") // 303xx: interfaces and associated types diff --git a/source/slang/diagnostics.cpp b/source/slang/diagnostics.cpp index f08554b52..18e5f9a4d 100644 --- a/source/slang/diagnostics.cpp +++ b/source/slang/diagnostics.cpp @@ -209,7 +209,7 @@ void DiagnosticSink::diagnoseImpl(SourceLoc const& pos, DiagnosticInfo const& in if (diagnostic.severity >= Severity::Fatal) { // TODO: figure out a better policy for aborting compilation - throw InvalidOperationException(); + throw AbortCompilationException(); } } diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index 9fda630bf..332125fc6 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -4754,7 +4754,7 @@ emitDeclImpl(decl, nullptr); void readSimpleIntVal() { int c = peek(); - if(isDigit(c)) + if(isDigit((char)c)) { get(); } diff --git a/source/slang/ir-legalize-types.cpp b/source/slang/ir-legalize-types.cpp index 677ccedd3..5b42407c2 100644 --- a/source/slang/ir-legalize-types.cpp +++ b/source/slang/ir-legalize-types.cpp @@ -469,12 +469,12 @@ static LegalVal legalizeFieldAddress( } } SLANG_UNEXPECTED("didn't find tuple element"); - return LegalVal(); + UNREACHABLE_RETURN(LegalVal()); } default: SLANG_UNEXPECTED("unhandled"); - return LegalVal(); + UNREACHABLE_RETURN(LegalVal()); } } diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 439bc7797..b1af6521c 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -55,15 +55,15 @@ namespace Slang } } - void IRUse::set(IRValue* usedValue) + void IRUse::set(IRValue* usedVal) { // clear out the old value - if (usedValue) + if (usedVal) { *prevLink = nextUse; } - init(user, usedValue); + init(user, usedVal); } void IRUse::clear() @@ -3750,8 +3750,6 @@ namespace Slang { if( auto subtypeWitness = dynamic_cast<SubtypeWitness*>(val) ) { - // We need to look up the IR value that represents the - // given subtype witness. String mangledName = getMangledNameForConformanceWitness( subtypeWitness->sub, subtypeWitness->sup); @@ -3948,8 +3946,11 @@ namespace Slang // has already been made. To do that we will need to // compute the mangled name of the specialized function, // so that we can look for existing declarations. - - String specMangledName = getMangledName(specDeclRef); + String specMangledName; + if (genericFunc->genericDecl == specDeclRef.decl) + specMangledName = getMangledName(specDeclRef); + else + specMangledName = mangleSpecializedFuncName(genericFunc->mangledName, specDeclRef.substitutions); // TODO: This is a terrible linear search, and we should // avoid it by building a dictionary ahead of time, @@ -3992,8 +3993,8 @@ namespace Slang // // TODO: This shouldn't be needed, if we introduce a sorting // step before we emit code. - specFunc->removeFromParent(); - specFunc->insertAfter(genericFunc); + //specFunc->removeFromParent(); + //specFunc->insertAfter(genericFunc); // At this point we've created a new non-generic function, // which means we should add it to our work list for @@ -4026,8 +4027,20 @@ namespace Slang auto keyDeclRef = ((IRDeclRef*) requirementKey)->declRef; // If the keys don't match, continue with the next entry. - if(!keyDeclRef.Equals(requirementDeclRef)) - continue; + if (!keyDeclRef.Equals(requirementDeclRef)) + { + // requirementDeclRef may be pointing to the inner decl of a generic decl + // in this case we compare keyDeclRef against the parent decl of requiredDeclRef + if (auto genRequiredDeclRef = requirementDeclRef.GetParent().As<GenericDecl>()) + { + if (!keyDeclRef.Equals(genRequiredDeclRef)) + { + continue; + } + } + else + continue; + } // If the keys matched, then we use the value from // this entry. @@ -4178,7 +4191,6 @@ namespace Slang // Use the witness table to look up the value that // satisfies the requirement. auto satisfyingVal = findWitnessVal(witnessTable, requirementDeclRef); - // We expect to always find something, but lets just // be careful here. if(!satisfyingVal) diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 2c64c539b..4fabd6a81 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -314,6 +314,13 @@ LoweredValInfo emitDeclRef( IRGenContext* context, DeclRef<Decl> declRef); +// Emit necessary `specialize` instruction needed by a declRef. +// This is currently used by emitDeclRef() and emitFuncRef() +LoweredValInfo maybeEmitSpecializeInst(IRGenContext* context, + LoweredValInfo loweredDecl, // the lowered value of the inner decl + DeclRef<Decl> declRef // the full decl ref containing substitutions +); + IRValue* getSimpleVal(IRGenContext* context, LoweredValInfo lowered); @@ -488,10 +495,11 @@ LoweredValInfo emitFuncRef( RefPtr<Type> type = funcExpr->type; - return LoweredValInfo::simple(context->irBuilder->emitLookupInterfaceMethodInst( + auto loweredVal = LoweredValInfo::simple(context->irBuilder->emitLookupInterfaceMethodInst( type, baseMemberDeclRef, funcDeclRef)); + return maybeEmitSpecializeInst(context, loweredVal, funcDeclRef); } } } @@ -2893,6 +2901,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> return LoweredValInfo::simple(irFunc); } + LoweredValInfo visitGenericDecl(GenericDecl * genDecl) + { + if (auto innerFuncDecl = genDecl->inner->As<FuncDecl>()) + return lowerFuncDecl(innerFuncDecl); + SLANG_RELEASE_ASSERT(false); + UNREACHABLE_RETURN(LoweredValInfo()); + } LoweredValInfo visitFunctionDeclBase(FunctionDeclBase* decl) { @@ -2997,6 +3012,11 @@ RefPtr<Val> lowerSubstitutionArg( } else if (auto declaredSubtypeWitness = dynamic_cast<DeclaredSubtypeWitness*>(val)) { + // We do not have a concrete witness table yet for a GenericTypeConstraintDecl witness + + if (declaredSubtypeWitness->declRef.As<GenericTypeConstraintDecl>()) + return val; + // We need to look up the IR-level representation of the witness // (which is a witness table). @@ -3073,9 +3093,16 @@ LoweredValInfo emitDeclRef( // unspecialized declaration. LoweredValInfo loweredDecl = ensureDecl(context, declRef.getDecl()); + return maybeEmitSpecializeInst(context, loweredDecl, declRef); +} + +LoweredValInfo maybeEmitSpecializeInst(IRGenContext* context, + LoweredValInfo loweredDecl, + DeclRef<Decl> declRef) +{ // If this declaration reference doesn't involve any specializations, // then we are done at this point. - if(!hasGenericSubstitutions(declRef.substitutions)) + if (!hasGenericSubstitutions(declRef.substitutions)) return loweredDecl; auto val = getSimpleVal(context, loweredDecl); @@ -3089,7 +3116,7 @@ LoweredValInfo emitDeclRef( RefPtr<Type> type; - if(auto declType = val->getType()) + if (auto declType = val->getType()) { type = declType->Substitute(declRef.substitutions).As<Type>(); } @@ -3102,6 +3129,7 @@ LoweredValInfo emitDeclRef( declRef)); } + static void lowerEntryPointToIR( IRGenContext* context, EntryPointRequest* entryPointRequest) diff --git a/source/slang/mangle.cpp b/source/slang/mangle.cpp index dca48f671..38c2b29a8 100644 --- a/source/slang/mangle.cpp +++ b/source/slang/mangle.cpp @@ -350,6 +350,24 @@ namespace Slang DeclRef<Decl>(declRef.decl, declRef.substitutions)); } + String mangleSpecializedFuncName(String baseName, RefPtr<Substitutions> subst) + { + ManglingContext context; + emitRaw(&context, baseName.Buffer()); + emitRaw(&context, "_G"); + while (subst) + { + if (auto genSubst = subst.As<GenericSubstitution>()) + { + for (auto a : genSubst->args) + emitVal(&context, a); + break; + } + subst = subst->outer; + } + return context.sb.ProduceString(); + } + String getMangledName(Decl* decl) { return getMangledName(makeDeclRef(decl)); diff --git a/source/slang/mangle.h b/source/slang/mangle.h index 60a1dff9e..29101a926 100644 --- a/source/slang/mangle.h +++ b/source/slang/mangle.h @@ -11,7 +11,7 @@ namespace Slang String getMangledName(Decl* decl); String getMangledName(DeclRef<Decl> const & declRef); String getMangledName(DeclRefBase const & declRef); - + String mangleSpecializedFuncName(String baseName, RefPtr<Substitutions> subst); String getMangledNameForConformanceWitness( Type* sub, Type* sup); diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index 595d58d7e..eb02d98c5 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -77,7 +77,7 @@ namespace Slang , sink(sink) , outerScope(outerScope) {} - + Parser(const Parser & other) = default; Session* getSession() { return translationUnit->compileRequest->mSession; @@ -1396,6 +1396,69 @@ namespace Slang return genericApp; } + static bool isGenericName(Parser* parser, Name* name) + { + auto lookupResult = lookUp( + parser->getSession(), + nullptr, // no semantics visitor available yet + name, + parser->currentScope); + if (!lookupResult.isValid() || lookupResult.isOverloaded()) + return false; + + auto decl = lookupResult.item.declRef.getDecl(); + if (auto genericDecl = dynamic_cast<GenericDecl*>(decl)) + { + return true; + } + else + { + return false; + } + } + + static RefPtr<Expr> tryParseGenericApp( + Parser* parser, + RefPtr<Expr> base) + { + Name * baseName = nullptr; + if (auto varExpr = base->As<VarExpr>()) + baseName = varExpr->name; + // if base is a known generics, parse as generics + if (baseName && isGenericName(parser, baseName)) + return parseGenericApp(parser, base); + + // otherwise, we speculate as generics, and fallback to comparison when parsing failed + TokenSpan tokenSpan; + tokenSpan.mBegin = parser->tokenReader.mCursor; + tokenSpan.mEnd = parser->tokenReader.mEnd; + DiagnosticSink newSink; + newSink.sourceManager = parser->sink->sourceManager; + Parser newParser(*parser); + newParser.sink = &newSink; + auto speculateParseRs = parseGenericApp(&newParser, base); + if (newSink.errorCount == 0) + { + // disambiguate based on FOLLOW set + switch (peekTokenType(&newParser)) + { + case TokenType::Dot: + case TokenType::LParent: + case TokenType::RParent: + case TokenType::RBracket: + case TokenType::Colon: + case TokenType::Comma: + case TokenType::QuestionMark: + case TokenType::Semicolon: + case TokenType::OpEql: + case TokenType::OpNeq: + { + return parseGenericApp(parser, base); + } + } + } + return base; + } static RefPtr<Expr> parseMemberType(Parser * parser, RefPtr<Expr> base) { RefPtr<MemberExpr> memberExpr = new MemberExpr(); @@ -2624,28 +2687,6 @@ namespace Slang return stmt; } - static bool isGenericName(Parser* parser, Name* name) - { - auto lookupResult = lookUp( - parser->getSession(), - nullptr, // no semantics visitor available yet - name, - parser->currentScope); - if(!lookupResult.isValid() || lookupResult.isOverloaded()) - return false; - - auto decl = lookupResult.item.declRef.getDecl(); - if( auto genericDecl = dynamic_cast<GenericDecl*>(decl) ) - { - return true; - } - else - { - return false; - } - } - - static bool isTypeName(Parser* parser, Name* name) { auto lookupResult = lookUp( @@ -3367,24 +3408,18 @@ namespace Slang } #endif } - + // We *might* be looking at an application of a generic to arguments, // but we need to disambiguate to make sure. static RefPtr<Expr> maybeParseGenericApp( Parser* parser, // TODO: need to support more general expressions here - RefPtr<VarExpr> base) + RefPtr<Expr> base) { if(peekTokenType(parser) != TokenType::OpLess) return base; - - if(!isGenericName(parser, base->name)) - return base; - - // Okay, seems likely that we are looking at a generic app - - return parseGenericApp(parser, base); + return tryParseGenericApp(parser, base); } static RefPtr<Expr> parsePrefixExpr(Parser* parser); @@ -3769,7 +3804,10 @@ namespace Slang parser->ReadToken(TokenType::Dot); memberExpr->name = expectIdentifier(parser).name; - expr = memberExpr; + if (peekTokenType(parser) == TokenType::OpLess) + expr = maybeParseGenericApp(parser, memberExpr); + else + expr = memberExpr; } break; } diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 1cda8e7d0..fe24fbd19 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -246,13 +246,13 @@ int CompileRequest::executeActionsInner() // Do some cleanup on settings specified by user. // In particular, we want to propagate flags from the overall request down to // each translation unit. - for( auto& translationUnit : translationUnits ) + for (auto& translationUnit : translationUnits) { translationUnit->compileFlags |= compileFlags; // However, the "no checking" flag shouldn't be applied to // any translation unit that is native Slang code. - if( translationUnit->sourceLanguage == SourceLanguage::Slang ) + if (translationUnit->sourceLanguage == SourceLanguage::Slang) { translationUnit->compileFlags &= ~SLANG_COMPILE_FLAG_NO_CHECKING; } @@ -282,7 +282,7 @@ int CompileRequest::executeActionsInner() // a pass-through compilation. // // Note that we *do* perform output generation as normal in pass-through mode. - if( passThrough == PassThroughMode::None ) + if (passThrough == PassThroughMode::None) { // Parse everything from the input files requested for (auto& translationUnit : translationUnits) @@ -316,7 +316,7 @@ int CompileRequest::executeActionsInner() return 1; } } - + // If command line specifies to skip codegen, we exit here. // Note: this is a debugging option. if (shouldSkipCodegen) @@ -901,7 +901,7 @@ SLANG_API int spCompile( { auto req = REQ(request); -#if 0 +#if !defined(SLANG_DEBUG_INTERNAL_ERROR) // By default we'd like to catch as many internal errors as possible, // and report them to the user nicely (rather than just crash their // application). Internally Slang currently uses exceptions for this. diff --git a/source/slang/slang.natvis b/source/slang/slang.natvis index 5bf8a1b09..91e9d0994 100644 --- a/source/slang/slang.natvis +++ b/source/slang/slang.natvis @@ -97,4 +97,17 @@ <ExpandedItem>usedValue</ExpandedItem> </Expand> </Type> + <Type Name="Slang::IRModule"> + <Expand> + <Item Name="session">session</Item> + <LinkedListItems> + <HeadPointer>firstGlobalValue</HeadPointer> + <NextPointer>nextGlobalValue</NextPointer> + <ValueNode>this</ValueNode> + </LinkedListItems> + </Expand> + </Type> + <Type Name="Slang::IRGlobalValue"> + <DisplayString>{{IRGlobalValue {op} {mangledName}}}</DisplayString> + </Type> </AutoVisualizer>
\ No newline at end of file diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index 9025c545a..4e1778e6e 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -1527,6 +1527,61 @@ void Type::accept(IValVisitor* visitor, void* extra) && declRef.Equals(otherWitness->declRef); } + RefPtr<Val> DeclaredSubtypeWitness::SubstituteImpl(Substitutions* subst, int * ioDiff) + { + DeclRef<GenericTypeParamDecl> genParamDeclRef; + if (auto subDeclRefType = this->sub.As<DeclRefType>()) + { + genParamDeclRef = subDeclRefType->declRef.As<GenericTypeParamDecl>(); + } + if (!genParamDeclRef) + return this; + auto genParamDecl = genParamDeclRef.getDecl(); + // search for a substitution that might apply to us + for (auto s = subst; s; s = s->outer.Ptr()) + { + if (auto genericSubst = dynamic_cast<GenericSubstitution*>(s)) + { + // the generic decl associated with the substitution list must be + // the generic decl that declared this parameter + auto genericDecl = genericSubst->genericDecl; + if (genericDecl != genParamDecl->ParentDecl) + continue; + bool found = false; + int index = 0; + for (auto m : genericDecl->Members) + { + if (m.Ptr() == genParamDecl) + { + // We've found it, so return the corresponding specialization argument + (*ioDiff)++; + found = true; + break; + } + else if (auto typeParam = m.As<GenericTypeParamDecl>()) + { + index++; + } + else if (auto valParam = m.As<GenericValueParamDecl>()) + { + index++; + } + else + { + } + } + if (found) + { + auto ordinaryParamCount = genericDecl->getMembersOfType<GenericTypeParamDecl>().Count() + + genericDecl->getMembersOfType<GenericValueParamDecl>().Count(); + SLANG_ASSERT(ordinaryParamCount + index < genericSubst->args.Count()); + return genericSubst->args[ordinaryParamCount + index]; + } + } + } + return this; + } + String DeclaredSubtypeWitness::ToString() { StringBuilder sb; @@ -1655,4 +1710,16 @@ void Type::accept(IValVisitor* visitor, void* extra) return false; } + RefPtr<GenericSubstitution> getGenericSubstitution(RefPtr<Substitutions> subst) + { + auto p = subst.Ptr(); + while (p) + { + if (auto genSubst = dynamic_cast<GenericSubstitution*>(p)) + return genSubst; + p = p->outer.Ptr(); + } + return nullptr; + } + } diff --git a/source/slang/syntax.h b/source/slang/syntax.h index bda43b336..46beca2d9 100644 --- a/source/slang/syntax.h +++ b/source/slang/syntax.h @@ -1160,7 +1160,7 @@ namespace Slang RefPtr<ThisTypeSubstitution> getThisTypeSubst(DeclRefBase & declRef, bool insertSubstEntry); void removeSubstitution(DeclRefBase & declRef, RefPtr<Substitutions> subst); bool hasGenericSubstitutions(RefPtr<Substitutions> subst); - + RefPtr<GenericSubstitution> getGenericSubstitution(RefPtr<Substitutions> subst); } // namespace Slang #endif
\ No newline at end of file diff --git a/source/slang/val-defs.h b/source/slang/val-defs.h index 4513873e3..5a370d34a 100644 --- a/source/slang/val-defs.h +++ b/source/slang/val-defs.h @@ -95,6 +95,7 @@ RAW( virtual bool EqualsVal(Val* val) override; virtual String ToString() override; virtual int GetHashCode() override; + virtual RefPtr<Val> SubstituteImpl(Substitutions * subst, int * ioDiff) override; ) END_SYNTAX_CLASS() diff --git a/tests/bugs/nested-generics-call.slang b/tests/bugs/nested-generics-call.slang new file mode 100644 index 000000000..2c6df59c6 --- /dev/null +++ b/tests/bugs/nested-generics-call.slang @@ -0,0 +1,34 @@ +//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out + +RWStructuredBuffer<float> outputBuffer; + +interface IBase +{ + float get(); +} + +struct BaseImpl : IBase +{ + float get() { return 1.0; } +}; + +__generic<T:IBase> +float eval(T obj) +{ + return obj.get(); +} + +__generic<T : IBase> +float test(T obj) +{ + return eval<T>(obj); +} + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + BaseImpl base; + float outVal = test<BaseImpl>(base); + outputBuffer[dispatchThreadID.x] = outVal; +}
\ No newline at end of file diff --git a/tests/bugs/nested-generics-call.slang.expected.txt b/tests/bugs/nested-generics-call.slang.expected.txt new file mode 100644 index 000000000..cc5e55ab6 --- /dev/null +++ b/tests/bugs/nested-generics-call.slang.expected.txt @@ -0,0 +1,4 @@ +3F800000 +3F800000 +3F800000 +3F800000 diff --git a/tests/bugs/nested-generics-method-call.slang b/tests/bugs/nested-generics-method-call.slang new file mode 100644 index 000000000..d1e80da57 --- /dev/null +++ b/tests/bugs/nested-generics-method-call.slang @@ -0,0 +1,38 @@ +//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out + +RWStructuredBuffer<float> outputBuffer; + +interface IBase +{ + float get(); +} + +struct BaseImpl : IBase +{ + float get() { return 1.0; } +}; + +struct S +{ + __generic<T:IBase> + float eval(T obj) + { + return obj.get(); + } +}; + +__generic<T : IBase> +float test(T obj) +{ + S s; + return s.eval<T>(obj); +} + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + BaseImpl base; + float outVal = test<BaseImpl>(base); + outputBuffer[dispatchThreadID.x] = outVal; +}
\ No newline at end of file diff --git a/tests/bugs/nested-generics-method-call.slang.expected.txt b/tests/bugs/nested-generics-method-call.slang.expected.txt new file mode 100644 index 000000000..cc5e55ab6 --- /dev/null +++ b/tests/bugs/nested-generics-method-call.slang.expected.txt @@ -0,0 +1,4 @@ +3F800000 +3F800000 +3F800000 +3F800000 diff --git a/tests/compute/generic-interface-method-simple.slang b/tests/compute/generic-interface-method-simple.slang new file mode 100644 index 000000000..7ba129492 --- /dev/null +++ b/tests/compute/generic-interface-method-simple.slang @@ -0,0 +1,48 @@ +//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out + +RWStructuredBuffer<float> outputBuffer; + +interface IVertexInterpolant +{ + float2 getUV(); +} + +interface IBRDFPattern +{ + __generic<TVertexInterpolant:IVertexInterpolant> + float evalPattern(TVertexInterpolant interpolants); +} + +struct StandardVertexInterpolant : IVertexInterpolant +{ + float2 getUV() { return float2(0.5); } +}; + +struct MaterialPattern1 : IBRDFPattern +{ + float base; + __generic<TVertexInterpolant:IVertexInterpolant> + float evalPattern(TVertexInterpolant interpolants) + { + float rs = base + interpolants.getUV().x; + return rs; + } +}; + +__generic<TPattern : IBRDFPattern, TInterpolant: IVertexInterpolant> +float test(TPattern pattern, TInterpolant vertInterps) +{ + float rs = pattern.evalPattern<TInterpolant>(vertInterps); + return rs; +} + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + StandardVertexInterpolant vertInterp; + MaterialPattern1 mp1; + mp1.base = 0.5; + float outVal = test<MaterialPattern1, StandardVertexInterpolant>(mp1, vertInterp); + outputBuffer[dispatchThreadID.x] = outVal; +}
\ No newline at end of file diff --git a/tests/compute/generic-interface-method-simple.slang.expected.txt b/tests/compute/generic-interface-method-simple.slang.expected.txt new file mode 100644 index 000000000..e143b7f20 --- /dev/null +++ b/tests/compute/generic-interface-method-simple.slang.expected.txt @@ -0,0 +1,4 @@ +3F800000 +3F800000 +3F800000 +3F800000
\ No newline at end of file diff --git a/tests/compute/generic-interface-method.slang b/tests/compute/generic-interface-method.slang new file mode 100644 index 000000000..e4fa8cff5 --- /dev/null +++ b/tests/compute/generic-interface-method.slang @@ -0,0 +1,86 @@ +//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out + +RWStructuredBuffer<float> outputBuffer; + +struct DisneyBRDFPattern +{ + float3 baseColor; + float3 normal; + float specular, metallic, roughness; + float opacity; + float3 emmissive; + float ambientOcclusion; +}; + +struct VertexPosition +{ + float3 pos; + float3 normal; + float2 uv; +}; + +struct CameraView +{ + float3 camPos; + float3 camDir; +}; + +interface IVertexInterpolant +{ + float4 getVertexColor(int index); + int getVertexColorCount(); + float2 getUV(int index); + int getUVCount(); +} + +interface IDisneyBRDFPattern +{ + __generic<TVertexInterpolant:IVertexInterpolant> + DisneyBRDFPattern evalPattern( + CameraView cam, + VertexPosition vWorld, + VertexPosition vObject, + TVertexInterpolant interpolants); +} + +struct StandardVertexInterpolant : IVertexInterpolant +{ + float4 getVertexColor(int index) { return float4(0.0); } + int getVertexColorCount() { return 0;} + float2 getUV(int index) { return float2(0.0); } + int getUVCount() {return 1; } +}; + +struct MaterialPattern1 : IDisneyBRDFPattern +{ + __generic<TVertexInterpolant:IVertexInterpolant> + DisneyBRDFPattern evalPattern( + CameraView cam, + VertexPosition vWorld, + VertexPosition vObject, + TVertexInterpolant interpolants) + { + DisneyBRDFPattern rs; + rs.baseColor = float3(0.5); + rs.opacity = 1.0; + return rs; + } +}; + +__generic<TVertexInterpolant:IVertexInterpolant, TPattern : IDisneyBRDFPattern> +float test(TVertexInterpolant vertInterps, TPattern pattern) +{ + CameraView cam; + VertexPosition vW, vO; + DisneyBRDFPattern rs = pattern.evalPattern(cam, vW, vO, vertInterps); + return rs.baseColor.x; +} +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + StandardVertexInterpolant vertInterp; + MaterialPattern1 mp1; + float outVal = test<StandardVertexInterpolant, MaterialPattern1>(vertInterp, mp1); + outputBuffer[dispatchThreadID.x] = outVal; +}
\ No newline at end of file diff --git a/tests/compute/generic-interface-method.slang.expected.txt b/tests/compute/generic-interface-method.slang.expected.txt new file mode 100644 index 000000000..e4e4c642a --- /dev/null +++ b/tests/compute/generic-interface-method.slang.expected.txt @@ -0,0 +1,4 @@ +3F000000 +3F000000 +3F000000 +3F000000 |
