diff options
| author | Yong He <yonghe@outlook.com> | 2017-11-03 09:38:02 -0400 |
|---|---|---|
| committer | Yong He <yonghe@outlook.com> | 2017-11-03 09:38:02 -0400 |
| commit | a0458266d7cd5d802b8c51e6a997b4bf0d9beb82 (patch) | |
| tree | 39f16538178907240e59b8e531ae153391805833 | |
| parent | d5e2319c33115d0241dd9d2047c0a5f029553dde (diff) | |
in-progress work
| -rw-r--r-- | source/core/core.natvis | 2 | ||||
| -rw-r--r-- | source/slang/bytecode.cpp | 56 | ||||
| -rw-r--r-- | source/slang/bytecode.h | 2 | ||||
| -rw-r--r-- | source/slang/check.cpp | 118 | ||||
| -rw-r--r-- | source/slang/decl-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/emit.cpp | 11 | ||||
| -rw-r--r-- | source/slang/ir.cpp | 97 | ||||
| -rw-r--r-- | source/slang/lookup.cpp | 32 | ||||
| -rw-r--r-- | source/slang/lower-to-ir.cpp | 40 | ||||
| -rw-r--r-- | source/slang/lower.cpp | 43 | ||||
| -rw-r--r-- | source/slang/mangle.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang.natvis | 27 | ||||
| -rw-r--r-- | source/slang/syntax-base-defs.h | 54 | ||||
| -rw-r--r-- | source/slang/syntax.cpp | 340 | ||||
| -rw-r--r-- | source/slang/type-defs.h | 55 | ||||
| -rw-r--r-- | source/slang/vm.cpp | 32 | ||||
| -rw-r--r-- | tests/compute/assoctype-complex.slang | 20 | ||||
| -rw-r--r-- | tests/compute/generics-constraint1.slang | 2 | ||||
| -rw-r--r-- | tests/compute/generics-constructor.slang | 17 |
19 files changed, 434 insertions, 522 deletions
diff --git a/source/core/core.natvis b/source/core/core.natvis index 3d9ac702e..91fdbb49b 100644 --- a/source/core/core.natvis +++ b/source/core/core.natvis @@ -54,7 +54,7 @@ </Expand> </Type> -<Type Name="Slang::RefPtrImpl<*,*,*>"> +<Type Name="Slang::RefPtr<*>"> <SmartPointer Usage="Minimal">pointer</SmartPointer> <DisplayString Condition="pointer == 0">empty</DisplayString> <DisplayString Condition="pointer != 0">RefPtr {*pointer}</DisplayString> diff --git a/source/slang/bytecode.cpp b/source/slang/bytecode.cpp index f1b23849d..ee055a01f 100644 --- a/source/slang/bytecode.cpp +++ b/source/slang/bytecode.cpp @@ -77,7 +77,6 @@ struct BytecodeGenerationPtr BytecodeGenerationPtr<T> operator+(Int index) const { - UInt size = sizeof(T); Int delta = index * sizeof(T); UInt newOffset = offset + delta; return BytecodeGenerationPtr<T>( @@ -157,7 +156,7 @@ BCPtr<void>::RawVal allocateRaw( for(size_t ii = currentOffset; ii < endOffset; ++ii) context->shared->bytecode.Add(0); - return beginOffset; + return (BCPtr<void>::RawVal)beginOffset; } template<typename T> @@ -196,7 +195,7 @@ void encodeUInt( { if( value < 128 ) { - encodeUInt8(context, value); + encodeUInt8(context, (uint8_t)value); return; } @@ -256,9 +255,8 @@ BCConst getGlobalValue( UInt constID = context->shared->constants.Count(); context->shared->constants.Add(value); - BCConst bcConst; bcConst.flavor = kBCConstFlavor_Constant; - bcConst.id = constID; + bcConst.id = (uint32_t)constID; context->shared->mapValueToGlobal.Add(value, bcConst); @@ -270,10 +268,10 @@ BCConst getGlobalValue( break; } + bcConst.flavor = (uint32_t) -1; + bcConst.id = (uint32_t)-9999; SLANG_UNEXPECTED("no ID for inst"); - bcConst.flavor = (BCConstFlavor) -1; - bcConst.id = -9999; - return bcConst; + //return bcConst; } Int getLocalID( @@ -348,7 +346,6 @@ void generateBytecodeForInst( // auto argCount = inst->getArgCount(); - auto type = inst->getType(); encodeUInt(context, inst->op); encodeOperand(context, inst->getType()); encodeUInt(context, argCount); @@ -400,9 +397,9 @@ void generateBytecodeForInst( unsigned char buffer[size]; memcpy(buffer, &ii->u.floatVal, sizeof(buffer)); - for(UInt ii = 0; ii < size; ++ii) + for(UInt i = 0; i < size; ++i) { - encodeUInt8(context, buffer[ii]); + encodeUInt8(context, buffer[i]); } // destination: @@ -479,7 +476,7 @@ BytecodeGenerationPtr<BCType> emitBCType( auto bcArgs = (bcType + 1).bitCast<BCPtr<uint8_t>>(); bcType->op = op; - bcType->argCount = argCount; + bcType->argCount = (uint32_t)argCount; for(UInt aa = 0; aa < argCount; ++aa) { @@ -489,7 +486,7 @@ BytecodeGenerationPtr<BCType> emitBCType( UInt id = context->shared->bcTypes.Count(); context->shared->mapTypeToID.Add(type, id); context->shared->bcTypes.Add(bcType); - bcType->id = id; + bcType->id = (uint32_t)id; return bcType; } @@ -577,7 +574,6 @@ BytecodeGenerationPtr<BCType> emitBCTypeImpl( SLANG_UNEXPECTED("unimplemented"); - return BytecodeGenerationPtr<BCType>(); } BytecodeGenerationPtr<BCType> emitBCType( @@ -703,7 +699,7 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst( // Allocate the array of block objects to be stored in the // bytecode file. auto bcBlocks = allocateArray<BCBlock>(context, blockCount); - bcFunc->blockCount = blockCount; + bcFunc->blockCount = (uint32_t)blockCount; bcFunc->blocks = bcBlocks; // Now loop through the blocks again, and allocate the storage @@ -750,7 +746,7 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst( } } - bcBlocks[blockID].paramCount = paramCount; + bcBlocks[blockID].paramCount = (uint32_t)paramCount; } // Okay, we've counted how many registers we need for each block, @@ -758,7 +754,7 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst( UInt regCount = regCounter; auto bcRegs = allocateArray<BCReg>(context, regCount); - bcFunc->regCount = regCount; + bcFunc->regCount = (uint32_t)regCount; bcFunc->regs = bcRegs; // Now we will loop over things again to fill in the information @@ -786,7 +782,7 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst( #if 0 bcRegs[localID].name = tryGenerateNameForSymbol(context, pp); #endif - bcRegs[localID].previousVarIndexPlusOne = localID; + bcRegs[localID].previousVarIndexPlusOne = (uint32_t)localID; bcRegs[localID].typeID = getTypeIDForGlobalSymbol(context, pp); } @@ -808,7 +804,7 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst( #if 0 bcRegs[localID].name = tryGenerateNameForSymbol(context, ii); #endif - bcRegs[localID].previousVarIndexPlusOne = localID; + bcRegs[localID].previousVarIndexPlusOne = (uint32_t)localID; bcRegs[localID].typeID = getTypeIDForGlobalSymbol(context, ii); } break; @@ -828,11 +824,11 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst( #if 0 bcRegs[localID].name = tryGenerateNameForSymbol(context, ii); #endif - bcRegs[localID].previousVarIndexPlusOne = localID; + bcRegs[localID].previousVarIndexPlusOne = (uint32_t)localID; bcRegs[localID].typeID = getTypeIDForGlobalSymbol(context, ii); bcRegs[localID+1].op = ii->op; - bcRegs[localID+1].previousVarIndexPlusOne = localID+1; + bcRegs[localID+1].previousVarIndexPlusOne = (uint32_t)localID+1; bcRegs[localID+1].typeID = getTypeID(context, (ii->getType()->As<PtrType>())->getValueType()); } @@ -840,7 +836,7 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst( } } } - assert(regCounter == regCount); + assert((UInt)regCounter == regCount); // Now that we've allocated our blocks and our registers // we can go through the actual process of emitting instructions. Hooray! @@ -851,7 +847,7 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst( List<UInt> blockOffsets; for( auto bb = irFunc->getFirstBlock(); bb; bb = bb->getNextBlock() ) { - UInt blockID = blockCounter++; + blockCounter++; // Get local bytecode offset for current block. UInt blockOffset = subContext->currentBytecode.Count(); @@ -903,7 +899,7 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst( UInt constCount = subContext->remappedGlobalSymbols.Count(); auto bcConsts = allocateArray<BCConst>(context, constCount); - bcFunc->constCount = constCount; + bcFunc->constCount = (uint32_t)constCount; bcFunc->consts = bcConsts; for( UInt cc = 0; cc < constCount; ++cc ) @@ -969,7 +965,7 @@ BytecodeGenerationPtr<BCModule> generateBytecodeForModule( // Ensure that local code inside functions can see these symbols BCConst bcConst; bcConst.flavor = kBCConstFlavor_GlobalSymbol; - bcConst.id = globalID; + bcConst.id = (uint32_t)globalID; context->shared->mapValueToGlobal.Add(gv, bcConst); // In the global scope, global IDs are also the local IDs @@ -978,7 +974,7 @@ BytecodeGenerationPtr<BCModule> generateBytecodeForModule( auto bcSymbols = allocateArray<BCPtr<BCSymbol>>(context, symbolCount); - bcModule->symbolCount = symbolCount; + bcModule->symbolCount = (uint32_t)symbolCount; bcModule->symbols = bcSymbols; for( auto gv = irModule->getFirstGlobalValue(); gv; gv = gv->getNextValue() ) @@ -998,7 +994,7 @@ BytecodeGenerationPtr<BCModule> generateBytecodeForModule( // At this point we should have identified all the literals we need: UInt constantCount = context->shared->constants.Count(); auto bcConstants = allocateArray<BCConstant>(context, constantCount); - bcModule->constantCount = constantCount; + bcModule->constantCount = (uint32_t)constantCount; bcModule->constants = bcConstants; for(UInt cc = 0; cc < constantCount; ++cc) @@ -1026,7 +1022,7 @@ BytecodeGenerationPtr<BCModule> generateBytecodeForModule( // At this point we should have collected all the types we need: UInt typeCount = context->shared->bcTypes.Count(); auto bcTypes = allocateArray<BCPtr<BCType>>(context, typeCount); - bcModule->typeCount = typeCount; + bcModule->typeCount = (uint32_t)typeCount; bcModule->types = bcTypes; for(UInt tt = 0; tt < typeCount; ++tt) @@ -1055,8 +1051,6 @@ void generateBytecodeContainer( // TODO: Need to dump BC representation of compiled kernel codes // for each specified code-generation target. - UInt translationUnitCount = compileReq->translationUnits.Count(); - List<BytecodeGenerationPtr<BCModule>> bcModulesList; for (auto translationUnitReq : compileReq->translationUnits) { @@ -1065,7 +1059,7 @@ void generateBytecodeContainer( } UInt bcModuleCount = bcModulesList.Count(); - header->moduleCount = bcModuleCount; + header->moduleCount = (uint32_t)bcModuleCount; auto bcModules = allocateArray<BCPtr<BCModule>>(context, bcModuleCount); header->modules = bcModules; diff --git a/source/slang/bytecode.h b/source/slang/bytecode.h index 75b9f15cd..f1ad52c32 100644 --- a/source/slang/bytecode.h +++ b/source/slang/bytecode.h @@ -56,7 +56,7 @@ struct BCPtr { if (ptr) { - rawVal = (char*)ptr - (char*)this; + rawVal = (RawVal)((char*)ptr - (char*)this); } else { diff --git a/source/slang/check.cpp b/source/slang/check.cpp index dc4e48545..137b1c451 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -148,15 +148,13 @@ namespace Slang if (baseExpr) { RefPtr<Expr> expr; - + if (baseExpr->type->As<TypeType>()) { auto sexpr = new StaticMemberExpr(); sexpr->loc = loc; sexpr->BaseExpression = baseExpr; sexpr->name = declRef.GetName(); - sexpr->type = GetTypeForDeclRef(declRef); - sexpr->declRef = declRef; expr = sexpr; } else @@ -165,54 +163,24 @@ namespace Slang sexpr->loc = loc; sexpr->BaseExpression = baseExpr; sexpr->name = declRef.GetName(); - sexpr->type = GetTypeForDeclRef(declRef); sexpr->declRef = declRef; expr = sexpr; } - if (auto constraintType = expr->type->As<GenericConstraintDeclRefType>()) + if (auto assocTypeDecl = declRef.As<AssocTypeDecl>()) { - if (baseExpr->type->As<TypeType>()) - constraintType->subType = baseExpr->type->As<TypeType>()->type; - else - constraintType->subType = baseExpr->type; - + RefPtr<ThisTypeSubstitution> subst = new ThisTypeSubstitution(); + subst->sourceType = baseExpr->type.type; + if (auto typeType = subst->sourceType.As<TypeType>()) + subst->sourceType = typeType->type; + expr->type = GetTypeForDeclRef(DeclRef<AssocTypeDecl>(assocTypeDecl.getDecl(), subst)); } - - if (auto genConstraintType = baseExpr->type->As<GenericConstraintDeclRefType>()) + else if (auto constraintDecl = declRef.As<GenericTypeConstraintDecl>()) { - if (auto funcDeclRef = declRef.As<CallableDecl>()) - { - // if this is call expression, propagate the source associated type to the result type - auto funcType = expr->type->As<FuncType>(); - if (auto assocRsType = funcType->resultType.As<AssocTypeDeclRefType>()) - { - RefPtr<FuncType> newFuncType = new FuncType(); - newFuncType->paramTypes = funcType->paramTypes; - RefPtr<AssocTypeDeclRefType> newRsType = new AssocTypeDeclRefType(); - newRsType->declRef = assocRsType->declRef; - newRsType->sourceType = genConstraintType->subType; - newRsType->setSession(getSession()); - newFuncType->resultType = newRsType; - newFuncType->setSession(funcType->getSession()); - expr->type = QualType(newFuncType); - } - } - else if (auto assocTypeDeclRef = declRef.As<AssocTypeDecl>()) - { - auto assocTypeDeclType = new AssocTypeDeclRefType(); - assocTypeDeclType->declRef = assocTypeDeclRef; - assocTypeDeclType->sourceType = genConstraintType->subType; - assocTypeDeclType->setSession(getSession()); - expr->type = QualType(getTypeType(assocTypeDeclType)); - } + expr->type = baseExpr->type; } - else if (auto assocTypeDeclRef = declRef.As<AssocTypeDecl>()) + else { - auto assocTypeDeclType = new AssocTypeDeclRefType(); - assocTypeDeclType->declRef = assocTypeDeclRef; - assocTypeDeclType->sourceType = baseExpr->type; - assocTypeDeclType->setSession(getSession()); - expr->type = QualType(getTypeType(assocTypeDeclType)); + expr->type = GetTypeForDeclRef(declRef); } return expr; } @@ -449,7 +417,7 @@ namespace Slang DeclRef<GenericDecl> genericDeclRef, List<RefPtr<Expr>> const& args) { - RefPtr<Substitutions> subst = new Substitutions(); + RefPtr<GenericSubstitution> subst = new GenericSubstitution(); subst->genericDecl = genericDeclRef.getDecl(); subst->outer = genericDeclRef.substitutions; @@ -2097,10 +2065,10 @@ namespace Slang return true; } - RefPtr<Substitutions> createDummySubstitutions( + RefPtr<GenericSubstitution> createDummySubstitutions( GenericDecl* genericDecl) { - RefPtr<Substitutions> subst = new Substitutions(); + RefPtr<GenericSubstitution> subst = new GenericSubstitution(); subst->genericDecl = genericDecl; for (auto dd : genericDecl->Members) { @@ -3115,7 +3083,7 @@ namespace Slang session, "Vector").As<GenericDecl>(); auto vectorTypeDecl = vectorGenericDecl->inner; - auto substitutions = new Substitutions(); + auto substitutions = new GenericSubstitution(); substitutions->genericDecl = vectorGenericDecl.Ptr(); substitutions->args.Add(elementType); substitutions->args.Add(elementCount); @@ -3447,7 +3415,7 @@ namespace Slang // to `interfaceDeclRef`. // SLANG_UNEXPECTED("reflexive type witness"); - return nullptr; + //return nullptr; } auto breadcrumbs = inBreadcrumbs; @@ -3462,7 +3430,7 @@ namespace Slang // because `A : B` and `B : C` then `A : C` // SLANG_UNEXPECTED("transitive type witness"); - return nullptr; + //return nullptr; } // Simple case: we have a single declaration @@ -3815,7 +3783,7 @@ namespace Slang // Consruct a reference to the extension with our constraint variables // as the - RefPtr<Substitutions> solvedSubst = new Substitutions(); + RefPtr<GenericSubstitution> solvedSubst = new GenericSubstitution(); solvedSubst->genericDecl = genericDeclRef.getDecl(); solvedSubst->outer = genericDeclRef.substitutions; solvedSubst->args = args; @@ -4084,8 +4052,9 @@ namespace Slang // We will go ahead and hang onto the arguments that we've // already checked, since downstream validation might need // them. - candidate.subst = new Substitutions(); - auto& checkedArgs = candidate.subst->args; + auto genSubst = new GenericSubstitution(); + candidate.subst = genSubst; + auto& checkedArgs = genSubst->args; int aa = 0; for (auto memberRef : getMembers(genericDeclRef)) @@ -4202,7 +4171,7 @@ namespace Slang // Create a witness that attests to the fact that `type` // is equal to itself. RefPtr<Val> createTypeEqualityWitness( - Type* type) + Type* /*type*/) { SLANG_UNEXPECTED("unimplemented"); } @@ -4258,7 +4227,7 @@ namespace Slang // We should have the existing arguments to the generic // handy, so that we can construct a substitution list. - RefPtr<Substitutions> subst = candidate.subst; + RefPtr<GenericSubstitution> subst = candidate.subst.As<GenericSubstitution>(); assert(subst); subst->genericDecl = genericDeclRef.getDecl(); @@ -4325,7 +4294,7 @@ namespace Slang RefPtr<Expr> createGenericDeclRef( RefPtr<Expr> baseExpr, RefPtr<Expr> originalExpr, - RefPtr<Substitutions> subst) + RefPtr<GenericSubstitution> subst) { auto baseDeclRefExpr = baseExpr.As<DeclRefExpr>(); if (!baseDeclRefExpr) @@ -4437,7 +4406,7 @@ namespace Slang return createGenericDeclRef( baseExpr, context.originalExpr, - candidate.subst); + candidate.subst.As<GenericSubstitution>()); break; default: @@ -4734,22 +4703,23 @@ namespace Slang // They must both be NULL or non-NULL if (!fst || !snd) return fst == snd; - + auto fstGen = fst.As<GenericSubstitution>(); + auto sndGen = snd.As<GenericSubstitution>(); // They must be specializing the same generic - if (fst->genericDecl != snd->genericDecl) + if (fstGen->genericDecl != sndGen->genericDecl) return false; // Their arguments must unify - SLANG_RELEASE_ASSERT(fst->args.Count() == snd->args.Count()); - UInt argCount = fst->args.Count(); + SLANG_RELEASE_ASSERT(fstGen->args.Count() == sndGen->args.Count()); + UInt argCount = fstGen->args.Count(); for (UInt aa = 0; aa < argCount; ++aa) { - if (!TryUnifyVals(constraints, fst->args[aa], snd->args[aa])) + if (!TryUnifyVals(constraints, fstGen->args[aa], sndGen->args[aa])) return false; } // Their "base" specializations must unify - if (!TryUnifySubstitutions(constraints, fst->outer, snd->outer)) + if (!TryUnifySubstitutions(constraints, fstGen->outer, sndGen->outer)) return false; return true; @@ -5304,11 +5274,12 @@ namespace Slang if( parentGenericDeclRef ) { SLANG_RELEASE_ASSERT(declRef.substitutions); - SLANG_RELEASE_ASSERT(declRef.substitutions->genericDecl == parentGenericDeclRef.getDecl()); + auto genSubst = declRef.substitutions.As<GenericSubstitution>(); + SLANG_RELEASE_ASSERT(genSubst->genericDecl == parentGenericDeclRef.getDecl()); sb << "<"; bool first = true; - for(auto arg : declRef.substitutions->args) + for(auto arg : genSubst->args) { if(!first) sb << ", "; formatVal(sb, arg); @@ -6069,7 +6040,7 @@ namespace Slang RefPtr<Expr> visitStaticMemberExpr(StaticMemberExpr* expr) { SLANG_UNEXPECTED("should not occur in unchecked AST"); - return expr; + //return expr; } RefPtr<Expr> lookupResultFailure( @@ -6495,26 +6466,11 @@ namespace Slang *outTypeResult = type; return QualType(getTypeType(type)); } - else if (auto constraintDeclRef = declRef.As<GenericTypeConstraintDecl>()) - { - // When we access a constraint or an inheritance decl (as a member), - // we are conceptually performing a "cast" to the given super-type, - // with the declaration showing that such a cast is legal. - auto type = new GenericConstraintDeclRefType(session, GetSub(constraintDeclRef), GetSup(constraintDeclRef)); - return QualType(type); - } else if (auto funcDeclRef = declRef.As<CallableDecl>()) { auto type = getFuncType(session, funcDeclRef); return QualType(type); } - else if (auto assocTypeDeclRef = declRef.As<AssocTypeDecl>()) - { - auto type = new AssocTypeDeclRefType(assocTypeDeclRef); - type->setSession(session); - *outTypeResult = type; - return QualType(getTypeType(type)); - } if( sink ) { sink->diagnose(declRef, Diagnostics::unimplemented, "cannot form reference to this kind of declaration"); @@ -6558,7 +6514,7 @@ namespace Slang if(decl != genericDecl->inner) return parentSubst; - RefPtr<Substitutions> subst = new Substitutions(); + RefPtr<GenericSubstitution> subst = new GenericSubstitution(); subst->genericDecl = genericDecl; subst->outer = parentSubst; diff --git a/source/slang/decl-defs.h b/source/slang/decl-defs.h index bb8c26f58..9c010d156 100644 --- a/source/slang/decl-defs.h +++ b/source/slang/decl-defs.h @@ -123,7 +123,7 @@ SYNTAX_CLASS(TypeDefDecl, SimpleTypeDecl) END_SYNTAX_CLASS() // An 'assoctype' declaration, it is a container of inheritance clauses -SYNTAX_CLASS(AssocTypeDecl, ContainerDecl) +SYNTAX_CLASS(AssocTypeDecl, AggTypeDecl) END_SYNTAX_CLASS() // A scope for local declarations (e.g., as part of a statement) diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index 9caad2e1a..2627e1f37 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -1052,15 +1052,6 @@ struct EmitVisitor EmitDeclarator(arg.declarator); } - void visitAssocTypeDeclRefType(AssocTypeDeclRefType* /*type*/, TypeEmitArg const& /*arg*/) - { - //SLANG_UNREACHABLE("visitAssocTypeDeclRefType in EmitVisitor"); - } - void visitGenericConstraintDeclRefType(GenericConstraintDeclRefType* /*type*/, TypeEmitArg const& /*arg*/) - { - //SLANG_UNREACHABLE("visitGenericConstraintDeclRefType in EmitVisitor"); - } - void visitBasicExpressionType(BasicExpressionType* basicType, TypeEmitArg const& arg) { auto declarator = arg.declarator; @@ -2925,7 +2916,7 @@ struct EmitVisitor return; } - Substitutions* subst = declRef.substitutions.Ptr(); + GenericSubstitution* subst = declRef.substitutions.As<GenericSubstitution>().Ptr(); if (!subst) return; diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 84ab3d9a5..5f3ca6c62 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -178,27 +178,27 @@ namespace Slang // // Add an instruction to a specific parent - void IRBuilder::addInst(IRBlock* block, IRInst* inst) + void IRBuilder::addInst(IRBlock* pblock, IRInst* inst) { - inst->parent = block; + inst->parent = pblock; - if (!block->firstInst) + if (!pblock->firstInst) { inst->prev = nullptr; inst->next = nullptr; - block->firstInst = inst; - block->lastInst = inst; + pblock->firstInst = inst; + pblock->lastInst = inst; } else { - auto prev = block->lastInst; + auto prev = pblock->lastInst; inst->prev = prev; inst->next = nullptr; prev->next = inst; - block->lastInst = inst; + pblock->lastInst = inst; } } @@ -206,7 +206,6 @@ namespace Slang void IRBuilder::addInst( IRInst* inst) { - auto insertBefore = insertBeforeInst; if(insertBeforeInst) { inst->insertBefore(insertBeforeInst); @@ -221,7 +220,7 @@ namespace Slang } static IRValue* createValueImpl( - IRBuilder* builder, + IRBuilder* /*builder*/, UInt size, IROp op, IRType* type) @@ -249,7 +248,7 @@ namespace Slang // arguments *after* the type (which is a mandatory // argument for all instructions). static IRInst* createInstImpl( - IRBuilder* builder, + IRBuilder* /*builder*/, UInt size, IROp op, IRType* type, @@ -261,8 +260,7 @@ namespace Slang IRInst* inst = (IRInst*) malloc(size); memset(inst, 0, size); - auto module = builder->getModule(); - inst->argCount = fixedArgCount + varArgCount; + inst->argCount = (uint32_t)(fixedArgCount + varArgCount); inst->op = op; @@ -632,7 +630,7 @@ namespace Slang IRInst* IRBuilder::emitCallInst( IRType* type, - IRValue* func, + IRValue* pFunc, UInt argCount, IRValue* const* args) { @@ -641,7 +639,7 @@ namespace Slang kIROp_Call, type, 1, - &func, + &pFunc, argCount, args); addInst(inst); @@ -829,12 +827,12 @@ namespace Slang IRFunc* IRBuilder::createFunc() { - IRFunc* func = createValue<IRFunc>( + IRFunc* rsFunc = createValue<IRFunc>( this, kIROp_Func, nullptr); - addGlobalValue(getModule(), func); - return func; + addGlobalValue(getModule(), rsFunc); + return rsFunc; } IRGlobalVar* IRBuilder::createGlobalVar( @@ -1129,13 +1127,13 @@ namespace Slang } IRInst* IRBuilder::emitBranch( - IRBlock* block) + IRBlock* pBlock) { auto inst = createInst<IRUnconditionalBranch>( this, kIROp_unconditionalBranch, nullptr, - block); + pBlock); addInst(inst); return inst; } @@ -1543,7 +1541,7 @@ namespace Slang if(genericParentDeclRef) { - auto subst = declRef.substitutions; + auto subst = declRef.substitutions.As<GenericSubstitution>(); if( !subst || subst->genericDecl != genericParentDeclRef.getDecl() ) { // No actual substitutions in place here @@ -1698,6 +1696,7 @@ namespace Slang dumpChildrenRaw(context, block); } +#if 0 static void dumpChildrenRaw( IRDumpContext* context, IRFunc* func) @@ -1720,6 +1719,7 @@ namespace Slang dumpIndent(context); dump(context, "}\n"); } +#endif static void dumpInst( IRDumpContext* context, @@ -2239,8 +2239,8 @@ namespace Slang void IRInst::removeArguments() { - UInt argCount = this->argCount; - for( UInt aa = 0; aa < argCount; ++aa ) + UInt oldArgCount = this->argCount; + for( UInt aa = 0; aa < oldArgCount; ++aa ) { IRUse& use = getArgs()[aa]; @@ -2384,9 +2384,9 @@ namespace Slang IRBuilder* builder, Type* type, VarLayout* varLayout, - TypeLayout* typeLayout, - LayoutResourceKind kind, - GlobalVaryingDeclarator* declarator) + TypeLayout* /*typeLayout*/, + LayoutResourceKind /*kind*/, + GlobalVaryingDeclarator* /*declarator*/) { // TODO: We might be creating an `in` or `out` variable based on // an `in out` function parameter. In this case we should @@ -2550,7 +2550,6 @@ namespace Slang default: SLANG_UNEXPECTED("unimplemented"); - return ScalarizedVal(); } } @@ -3079,7 +3078,6 @@ namespace Slang break; default: SLANG_UNEXPECTED("no value registered for IR value"); - return nullptr; } } @@ -3111,18 +3109,27 @@ namespace Slang { if (!subst) return nullptr; + if (auto genSubst = dynamic_cast<GenericSubstitution*>(subst)) + { + RefPtr<GenericSubstitution> newSubst = new GenericSubstitution(); + newSubst->outer = cloneSubstitutions(context, subst->outer); + newSubst->genericDecl = genSubst->genericDecl; - RefPtr<Substitutions> newSubst = new Substitutions(); - newSubst->outer = cloneSubstitutions(context, subst->outer); - newSubst->genericDecl = subst->genericDecl; - - for (auto arg : subst->args) + for (auto arg : genSubst->args) + { + auto newArg = cloneSubstitutionArg(context, arg); + newSubst->args.Add(arg); + } + return newSubst; + } + else if (auto thisSubst = dynamic_cast<ThisTypeSubstitution*>(subst)) { - auto newArg = cloneSubstitutionArg(context, arg); - newSubst->args.Add(arg); + RefPtr<ThisTypeSubstitution> newSubst = new ThisTypeSubstitution(); + newSubst->sourceType = thisSubst->sourceType; + newSubst->outer = cloneSubstitutions(context, subst->outer); + return newSubst; } - - return newSubst; + return nullptr; } DeclRef<Decl> IRSpecContext::maybeCloneDeclRef(DeclRef<Decl> const& declRef) @@ -3231,7 +3238,7 @@ namespace Slang { auto clonedKey = context->maybeCloneValue(originalEntry->requirementKey.usedValue); auto clonedVal = context->maybeCloneValue(originalEntry->satisfyingVal.usedValue); - auto clonedEntry = context->builder->createWitnessTableEntry( + context->builder->createWitnessTableEntry( clonedTable, clonedKey, clonedVal); @@ -3416,7 +3423,6 @@ namespace Slang default: SLANG_UNEXPECTED("unhandled case"); - return "unknown"; } } @@ -3518,7 +3524,6 @@ namespace Slang { // This shouldn't happen! SLANG_UNEXPECTED("no matching function registered"); - return cloneSimpleFunc(context, originalFunc); } // We will try to track the "best" definition we can find. @@ -3748,7 +3753,6 @@ namespace Slang else { SLANG_UNEXPECTED("unimplemented"); - return nullptr; } } @@ -3756,7 +3760,8 @@ namespace Slang IRGenericSpecContext* context, DeclRef<Decl> declRef) { - auto subst = context->subst; + auto subst = context->subst.As<GenericSubstitution>(); + SLANG_ASSERT(subst); auto genericDecl = subst->genericDecl; UInt orinaryParamCount = 0; @@ -3806,12 +3811,13 @@ namespace Slang { auto declRefVal = (IRDeclRef*) originalVal; auto declRef = declRefVal->declRef; - + auto genSubst = subst.As<GenericSubstitution>(); + SLANG_ASSERT(genSubst); // We may have a direct reference to one of the parameters // of the generic we are specializing, and in that case // we nee to translate it over to the equiavalent of // the `Val` we have been given. - if(declRef.getDecl()->ParentDecl == subst->genericDecl) + if(declRef.getDecl()->ParentDecl == genSubst->genericDecl) { return getSubstValue(this, declRef); } @@ -3866,9 +3872,10 @@ namespace Slang // using a different overload of a target-specific function, // so we need to create a dummy substitution here, to make // sure it used the correct generic. - RefPtr<Substitutions> newSubst = new Substitutions(); + RefPtr<GenericSubstitution> newSubst = new GenericSubstitution(); newSubst->genericDecl = genericFunc->genericDecl; - newSubst->args = specDeclRef.substitutions->args; + auto specDeclRefSubst = specDeclRef.substitutions.As<GenericSubstitution>(); + newSubst->args = specDeclRefSubst->args; IRGenericSpecContext context; context.shared = sharedContext; diff --git a/source/slang/lookup.cpp b/source/slang/lookup.cpp index b97dac560..a5425074f 100644 --- a/source/slang/lookup.cpp +++ b/source/slang/lookup.cpp @@ -452,28 +452,28 @@ void lookUpMemberImpl( lookUpMemberImpl(session, semantics, name, bound, ioResult, &breadcrumb); } } - } - else if (auto assocTypeDeclRefType = type->As<AssocTypeDeclRefType>()) - { - auto assocTypeDeclRef = assocTypeDeclRefType->declRef; - for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(assocTypeDeclRef)) + else if (auto assocTypeDeclRef = declRef.As<AssocTypeDecl>()) { - // The super-type in the constraint (e.g., `Foo` in `T : Foo`) - // will tell us a type we should use for lookup. - auto bound = GetSup(constraintDeclRef); + for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(assocTypeDeclRef)) + { + // The super-type in the constraint (e.g., `Foo` in `T : Foo`) + // will tell us a type we should use for lookup. + auto bound = GetSup(constraintDeclRef); - // Go ahead and use the target type, with an appropriate breadcrumb - // to indicate that we indirected through a type constraint. + // Go ahead and use the target type, with an appropriate breadcrumb + // to indicate that we indirected through a type constraint. - BreadcrumbInfo breadcrumb; - breadcrumb.prev = inBreadcrumbs; - breadcrumb.kind = LookupResultItem::Breadcrumb::Kind::Constraint; - breadcrumb.declRef = constraintDeclRef; + BreadcrumbInfo breadcrumb; + breadcrumb.prev = inBreadcrumbs; + breadcrumb.kind = LookupResultItem::Breadcrumb::Kind::Constraint; + breadcrumb.declRef = constraintDeclRef; - // TODO: Need to consider case where this might recurse infinitely. - lookUpMemberImpl(session, semantics, name, bound, ioResult, &breadcrumb); + // TODO: Need to consider case where this might recurse infinitely. + lookUpMemberImpl(session, semantics, name, bound, ioResult, &breadcrumb); + } } } + } LookupResult lookUpMember( diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 998197279..6099df1ed 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -870,9 +870,12 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower auto subs = declRef.substitutions; while(subs) { - for(auto aa : subs->args) + if (auto genSubst = subs.As<GenericSubstitution>()) { - (*ioArgs).Add(getSimpleVal(context, lowerVal(context, aa))); + for (auto aa : genSubst->args) + { + (*ioArgs).Add(getSimpleVal(context, lowerVal(context, aa))); + } } subs = subs->outer; } @@ -3037,24 +3040,33 @@ RefPtr<Substitutions> lowerSubstitutions( { if(!subst) return nullptr; + RefPtr<Substitutions> result; + if (auto genSubst = dynamic_cast<GenericSubstitution*>(subst)) + { + RefPtr<GenericSubstitution> newSubst = new GenericSubstitution(); + newSubst->genericDecl = genSubst->genericDecl; + + for (auto arg : genSubst->args) + { + auto newArg = lowerSubstitutionArg(context, arg); + newSubst->args.Add(newArg); + } - RefPtr<Substitutions> newSubst = new Substitutions(); + result = newSubst; + } + else if (auto thisSubst = dynamic_cast<ThisTypeSubstitution*>(subst)) + { + RefPtr<ThisTypeSubstitution> newSubst = new ThisTypeSubstitution(); + newSubst->sourceType = lowerSubstitutionArg(context, thisSubst->sourceType); + result = newSubst; + } if (subst->outer) { - newSubst->outer = lowerSubstitutions( + result->outer = lowerSubstitutions( context, subst->outer); } - - newSubst->genericDecl = subst->genericDecl; - - for (auto arg : subst->args) - { - auto newArg = lowerSubstitutionArg(context, arg); - newSubst->args.Add(newArg); - } - - return newSubst; + return result; } LoweredValInfo emitDeclRef( diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp index b810d9643..b8696df47 100644 --- a/source/slang/lower.cpp +++ b/source/slang/lower.cpp @@ -779,20 +779,6 @@ struct LoweringVisitor translateDeclRef(DeclRef<Decl>(type->declRef)).As<TypeDefDecl>()); } - RefPtr<Type> visitGenericConstraintDeclRefType(GenericConstraintDeclRefType* type) - { - // not supported by lowering - SLANG_UNREACHABLE("visitGenericConstraintDeclRefType in LowerVisitor"); - return nullptr; - } - - RefPtr<Type> visitAssocTypeDeclRefType(AssocTypeDeclRefType* type) - { - // not supported by lowering - SLANG_UNREACHABLE("visitAssocTypeDeclRefType in LowerVisitor"); - return nullptr; - } - RefPtr<Type> visitTypeType(TypeType* type) { return getTypeType(lowerType(type->type)); @@ -2569,14 +2555,23 @@ struct LoweringVisitor Substitutions* inSubstitutions) { if (!inSubstitutions) return nullptr; - - RefPtr<Substitutions> result = new Substitutions(); - result->genericDecl = translateDeclRef(inSubstitutions->genericDecl).As<GenericDecl>(); - for (auto arg : inSubstitutions->args) + if (auto genSubst = dynamic_cast<GenericSubstitution*>(inSubstitutions)) { - result->args.Add(translateVal(arg)); + RefPtr<GenericSubstitution> result = new GenericSubstitution(); + result->genericDecl = translateDeclRef(genSubst->genericDecl).As<GenericDecl>(); + for (auto arg : genSubst->args) + { + result->args.Add(translateVal(arg)); + } + return result; } - return result; + else if (auto thisSubst = dynamic_cast<ThisTypeSubstitution*>(inSubstitutions)) + { + RefPtr<ThisTypeSubstitution> result = new ThisTypeSubstitution(); + result->sourceType = translateVal(result->sourceType); + return result; + } + return nullptr; } static Decl* getModifiedDecl(Decl* decl) @@ -2733,7 +2728,11 @@ struct LoweringVisitor RefPtr<VarLayout> tryToFindLayout( Decl* decl) { - auto loweredParent = translateDeclRef(decl->ParentDecl); + RefPtr<Decl> loweredParent; + if (auto genericParentDecl = decl->ParentDecl->As<GenericDecl>()) + loweredParent = translateDeclRef(genericParentDecl->ParentDecl); + else + loweredParent = translateDeclRef(decl->ParentDecl); if (loweredParent) { auto layoutMod = loweredParent->FindModifier<ComputedLayoutModifier>(); @@ -3831,7 +3830,7 @@ struct LoweringVisitor "Vector").As<GenericDecl>(); auto vectorTypeDecl = vectorGenericDecl->inner; - auto substitutions = new Substitutions(); + auto substitutions = new GenericSubstitution(); substitutions->genericDecl = vectorGenericDecl.Ptr(); substitutions->args.Add(elementType); substitutions->args.Add(elementCount); diff --git a/source/slang/mangle.cpp b/source/slang/mangle.cpp index cd90a0e41..bde4e12c7 100644 --- a/source/slang/mangle.cpp +++ b/source/slang/mangle.cpp @@ -117,10 +117,6 @@ namespace Slang { emitQualifiedName(context, declRefType->declRef); } - else if (auto assocTypeDeclRefType = dynamic_cast<AssocTypeDeclRefType*>(type)) - { - emitQualifiedName(context, assocTypeDeclRefType->declRef); - } else { SLANG_UNEXPECTED("unimplemented case in mangling"); @@ -199,7 +195,7 @@ namespace Slang // There are two cases here: either we have specializations // in place for the parent generic declaration, or we don't. - auto subst = declRef.substitutions; + auto subst = declRef.substitutions.As<GenericSubstitution>(); if( subst && subst->genericDecl == parentGenericDeclRef.getDecl() ) { // This is the case where we *do* have substitutions. diff --git a/source/slang/slang.natvis b/source/slang/slang.natvis index 7a7f7fe0e..7e6fd3753 100644 --- a/source/slang/slang.natvis +++ b/source/slang/slang.natvis @@ -9,5 +9,30 @@ <ExpandedItem>rawVal ? ($T1*)((char*)this + rawVal) : ($T1*)0</ExpandedItem> </Expand> </Type> - + <Type Name="Slang::DeclRef<*>"> + <SmartPointer Usage="Minimal">decl ? ($T1*)(decl) : ($T1*)0</SmartPointer> + <DisplayString Condition="decl == 0">DeclRef nullptr</DisplayString> + <DisplayString Condition="decl != 0">DeclRef {(*(*(Slang::DeclRefBase*)this).decl).nameAndLoc}</DisplayString> + <Expand> + <ExpandedItem>decl ? ($T1*)(decl) : ($T1*)0</ExpandedItem> + <Item Name="[Substitutions]:">"========================="</Item> + <LinkedListItems> + <HeadPointer>substitutions.pointer</HeadPointer> + <NextPointer>outer.pointer</NextPointer> + <ValueNode>this</ValueNode> + </LinkedListItems> + </Expand> + </Type> + <Type Name="Slang::DeclRefType"> + <DisplayString>DeclRefType {declRef}</DisplayString> + <Expand> + <ExpandedItem>declRef</ExpandedItem> + </Expand> + </Type> + <Type Name="Slang::Name"> + <DisplayString>{{name={(char*)(text.buffer.pointer+1), s}}}</DisplayString> + </Type> + <Type Name="Slang::NameLoc"> + <DisplayString>{{name={(char*)((*name).text.buffer.pointer+1), s} loc={loc.raw}}}</DisplayString> + </Type> </AutoVisualizer>
\ No newline at end of file diff --git a/source/slang/syntax-base-defs.h b/source/slang/syntax-base-defs.h index 8a22a61d2..3c7e8c5ae 100644 --- a/source/slang/syntax-base-defs.h +++ b/source/slang/syntax-base-defs.h @@ -126,31 +126,44 @@ protected: ) END_SYNTAX_CLASS() + // A substitution represents a binding of certain // type-level variables to concrete argument values -SYNTAX_CLASS(Substitutions, RefObject) +ABSTRACT_SYNTAX_CLASS(Substitutions, RefObject) + + // Any further substitutions, relating to outer generic declarations + SYNTAX_FIELD(RefPtr<Substitutions>, outer) + + RAW( + // Apply a set of substitutions to the bindings in this substitution + virtual RefPtr<Substitutions> SubstituteImpl(Substitutions* subst, int* ioDiff) = 0; + // Check if these are equivalent substitutiosn to another set + virtual bool Equals(Substitutions* subst) = 0; + virtual bool operator == (const Substitutions & subst) = 0; + virtual int GetHashCode() const = 0; + ) +END_SYNTAX_CLASS() + +SYNTAX_CLASS(GenericSubstitution, Substitutions) // The generic declaration that defines the // parametesr we are binding to arguments - DECL_FIELD(GenericDecl*, genericDecl) + DECL_FIELD(GenericDecl*, genericDecl) // The actual values of the arguments SYNTAX_FIELD(List<RefPtr<Val>>, args) - - // Any further substitutions, relating to outer generic declarations - SYNTAX_FIELD(RefPtr<Substitutions>, outer) - + RAW( // Apply a set of substitutions to the bindings in this substitution - RefPtr<Substitutions> SubstituteImpl(Substitutions* subst, int* ioDiff); + virtual RefPtr<Substitutions> SubstituteImpl(Substitutions* subst, int* ioDiff) override; // Check if these are equivalent substitutiosn to another set - bool Equals(Substitutions* subst); - bool operator == (const Substitutions & subst) + virtual bool Equals(Substitutions* subst) override; + virtual bool operator == (const Substitutions & subst) override { return Equals(const_cast<Substitutions*>(&subst)); } - int GetHashCode() const + virtual int GetHashCode() const override { int rs = 0; for (auto && v : args) @@ -163,6 +176,27 @@ SYNTAX_CLASS(Substitutions, RefObject) ) END_SYNTAX_CLASS() +SYNTAX_CLASS(ThisTypeSubstitution, Substitutions) + // The actual type that provides the lookup scope for an associated type + SYNTAX_FIELD(RefPtr<Val>, sourceType) + + RAW( + // Apply a set of substitutions to the bindings in this substitution + virtual RefPtr<Substitutions> SubstituteImpl(Substitutions* subst, int* ioDiff) override; + + // Check if these are equivalent substitutiosn to another set + virtual bool Equals(Substitutions* subst) override; + virtual bool operator == (const Substitutions & subst) override + { + return Equals(const_cast<Substitutions*>(&subst)); + } + virtual int GetHashCode() const override + { + return sourceType->GetHashCode(); + } + ) +END_SYNTAX_CLASS() + ABSTRACT_SYNTAX_CLASS(SyntaxNode, SyntaxNodeBase) END_SYNTAX_CLASS() diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index 3e38955ba..a3a5fdcb6 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -91,6 +91,8 @@ ABSTRACT_SYNTAX_CLASS(Modifier, SyntaxNodeBase); ABSTRACT_SYNTAX_CLASS(Expr, SyntaxNode); ABSTRACT_SYNTAX_CLASS(Substitutions, SyntaxNode); +ABSTRACT_SYNTAX_CLASS(GenericSubstitution, Substitutions); +ABSTRACT_SYNTAX_CLASS(ThisTypeSubstitution, Substitutions); #include "expr-defs.h" #include "decl-defs.h" @@ -98,8 +100,6 @@ ABSTRACT_SYNTAX_CLASS(Substitutions, SyntaxNode); #include "stmt-defs.h" #include "type-defs.h" #include "val-defs.h" - - #include "object-meta-end.h" bool SyntaxClassBase::isSubClassOfImpl(SyntaxClassBase const& super) const @@ -283,7 +283,7 @@ void Type::accept(IValVisitor* visitor, void* extra) this, "PtrType").As<GenericDecl>(); auto typeDecl = genericDecl->inner; - auto substitutions = new Substitutions(); + auto substitutions = new GenericSubstitution(); substitutions->genericDecl = genericDecl.Ptr(); substitutions->args.Add(valueType); @@ -414,38 +414,57 @@ void Type::accept(IValVisitor* visitor, void* extra) // search for a substitution that might apply to us for (auto s = subst; s; s = s->outer.Ptr()) { - // the generic decl associated with the substitution list must be - // the generic decl that declared this parameter - auto genericDecl = s->genericDecl; - if (genericDecl != genericTypeParamDecl->ParentDecl) - continue; - - int index = 0; - for (auto m : genericDecl->Members) + if (auto genericSubst = dynamic_cast<GenericSubstitution*>(s)) { - if (m.Ptr() == genericTypeParamDecl) - { - // We've found it, so return the corresponding specialization argument - (*ioDiff)++; - return s->args[index]; - } - else if(auto typeParam = m.As<GenericTypeParamDecl>()) + // the generic decl associated with the substitution list must be + // the generic decl that declared this parameter + auto genericDecl = genericSubst->genericDecl; + if (genericDecl != genericTypeParamDecl->ParentDecl) + continue; + + int index = 0; + for (auto m : genericDecl->Members) { - index++; - } - else if(auto valParam = m.As<GenericValueParamDecl>()) - { - index++; + if (m.Ptr() == genericTypeParamDecl) + { + // We've found it, so return the corresponding specialization argument + (*ioDiff)++; + return genericSubst->args[index]; + } + else if (auto typeParam = m.As<GenericTypeParamDecl>()) + { + index++; + } + else if (auto valParam = m.As<GenericValueParamDecl>()) + { + index++; + } + else + { + } } - else + } + + } + } + // the second case we care about is when this decl type refers to an associatedtype decl + // we want to replace it with the actual associated type + else if (auto assocTypeDecl = dynamic_cast<AssocTypeDecl*>(declRef.getDecl())) + { + // search for a substitution that might apply to us + for (auto s = subst; s; s = s->outer.Ptr()) + { + if (auto thisTypeSubst = dynamic_cast<ThisTypeSubstitution*>(s)) + { + if (auto aggTypeDeclRef = thisTypeSubst->sourceType.As<DeclRefType>()->declRef.As<AggTypeDecl>()) { + Decl * targetType = nullptr; + if (aggTypeDeclRef.getDecl()->memberDictionary.TryGetValue(assocTypeDecl->getName(), targetType)) + return DeclRefType::Create(this->getSession(), DeclRef<Decl>(targetType, aggTypeDeclRef.substitutions)); } } - } } - - int diff = 0; DeclRef<Decl> substDeclRef = declRef.SubstituteImpl(subst, &diff); @@ -486,10 +505,25 @@ void Type::accept(IValVisitor* visitor, void* extra) // we will construct a default specialization at the use // site if needed. - if( auto genericParent = declRef.GetParent().As<GenericDecl>() ) + if (auto genericParent = declRef.GetParent().As<GenericDecl>()) { auto subst = declRef.substitutions; - if( !subst || subst->genericDecl != genericParent.decl ) + // try find a substitution targeting this generic decl + bool substFound = false; + while (subst) + { + if (auto genSubst = dynamic_cast<GenericSubstitution*>(subst.Ptr())) + { + if (genSubst->genericDecl == genericParent.decl) + { + substFound = true; + break; + } + } + subst = subst->outer; + } + // we did not find an existing substituion, create a default one + if (!substFound) { declRef.substitutions = createDefaultSubstitutions( session, @@ -507,7 +541,7 @@ void Type::accept(IValVisitor* visitor, void* extra) } else if (auto magicMod = declRef.getDecl()->FindModifier<MagicTypeModifier>()) { - Substitutions* subst = declRef.substitutions.Ptr(); + GenericSubstitution* subst = declRef.substitutions.As<GenericSubstitution>().Ptr(); if (magicMod->name == "SamplerState") { @@ -761,7 +795,6 @@ void Type::accept(IValVisitor* visitor, void* extra) bool NamedExpressionType::EqualsImpl(Type * /*type*/) { SLANG_UNEXPECTED("unreachable"); - return false; } Type* NamedExpressionType::CreateCanonicalType() @@ -772,7 +805,6 @@ void Type::accept(IValVisitor* visitor, void* extra) int NamedExpressionType::GetHashCode() { SLANG_UNEXPECTED("unreachable"); - return 0; } // FuncType @@ -910,7 +942,6 @@ void Type::accept(IValVisitor* visitor, void* extra) int TypeType::GetHashCode() { SLANG_UNEXPECTED("unreachable"); - return 0; } // GenericDeclRefType @@ -940,125 +971,6 @@ void Type::accept(IValVisitor* visitor, void* extra) return this; } - // AssocTypeDeclRefType - - String AssocTypeDeclRefType::ToString() - { - // TODO: what is appropriate here? - return "<AssocType>"; - } - - bool AssocTypeDeclRefType::EqualsImpl(Type * type) - { - if (auto assocTypeDeclRefType = type->As<AssocTypeDeclRefType>()) - { - return declRef.Equals(assocTypeDeclRefType->declRef); - } - return false; - } - - RefPtr<Val> AssocTypeDeclRefType::SubstituteImpl(Substitutions* subst, int* ioDiff) - { - if (!sourceType) - return this; - auto substSourceType = sourceType->SubstituteImpl(subst, ioDiff); - if (auto parentDeclRefType = substSourceType.As<DeclRefType>()) - { - auto parentDeclRef = parentDeclRefType->declRef; - DeclRef<AggTypeDecl> newParentDeclRef = parentDeclRef.As<AggTypeDecl>(); - // search for a substitution that might apply to us - for (auto s = subst; s; s = s->outer.Ptr()) - { - // the generic decl associated with the substitution list must be - // the generic decl that declared this parameter - auto genericDecl = s->genericDecl; - if (genericDecl != parentDeclRef.getDecl()->ParentDecl) - continue; - int index = 0; - for (auto m : genericDecl->Members) - { - if (m.Ptr() == parentDeclRef.getDecl()) - { - // We've found it, so return the corresponding specialization argument - (*ioDiff)++; - if (auto declRef = s->args[index].As<DeclRefType>()) - { - newParentDeclRef = (*declRef).declRef.As<AggTypeDecl>(); - goto searchEnd; - } - } - else if (auto typeParam = m.As<GenericTypeParamDecl>()) - { - index++; - } - else if (auto valParam = m.As<GenericValueParamDecl>()) - { - index++; - } - else - { - } - } - } - searchEnd: - if (newParentDeclRef) - { - Decl* targetTypeDecl = nullptr; - if (newParentDeclRef.getDecl()->memberDictionary.TryGetValue(this->GetDeclRef().decl->getName(), targetTypeDecl)) - { - if (auto typeDefDecl = targetTypeDecl->As<TypeDefDecl>()) - return GetType(DeclRef<TypeDefDecl>(typeDefDecl, subst)); - else - return DeclRefType::Create(this->getSession(), DeclRef<Decl>(targetTypeDecl, subst)); - } - } - } - - return this; - } - - int AssocTypeDeclRefType::GetHashCode() - { - return declRef.GetHashCode(); - } - - Type* AssocTypeDeclRefType::CreateCanonicalType() - { - return this; - } - - // GenericConstraintDeclRefType - - String GenericConstraintDeclRefType::ToString() - { - // TODO: what is appropriate here? - return "<GenericConstraintType>"; - } - - bool GenericConstraintDeclRefType::EqualsImpl(Type * type) - { - if (auto other = type->As<GenericConstraintDeclRefType>()) - { - return supType->Equals(other->supType) && subType->Equals(other->subType); - } - return false; - } - - RefPtr<Val> GenericConstraintDeclRefType::SubstituteImpl(Substitutions* subst, int* ioDiff) - { - return subType->SubstituteImpl(subst, ioDiff); - } - - int GenericConstraintDeclRefType::GetHashCode() - { - return combineHash(subType.GetHashCode(), supType.GetHashCode()); - } - - Type* GenericConstraintDeclRefType::CreateCanonicalType() - { - return this; - } - // ArithmeticExpressionType // VectorExpressionType @@ -1091,24 +1003,24 @@ void Type::accept(IValVisitor* visitor, void* extra) Type* MatrixExpressionType::getElementType() { - return this->declRef.substitutions->args[0].As<Type>().Ptr(); + return this->declRef.substitutions.As<GenericSubstitution>()->args[0].As<Type>().Ptr(); } IntVal* MatrixExpressionType::getRowCount() { - return this->declRef.substitutions->args[1].As<IntVal>().Ptr(); + return this->declRef.substitutions.As<GenericSubstitution>()->args[1].As<IntVal>().Ptr(); } IntVal* MatrixExpressionType::getColumnCount() { - return this->declRef.substitutions->args[2].As<IntVal>().Ptr(); + return this->declRef.substitutions.As<GenericSubstitution>()->args[2].As<IntVal>().Ptr(); } // PtrTypeBase Type* PtrTypeBase::getValueType() { - return this->declRef.substitutions->args[0].As<Type>().Ptr(); + return this->declRef.substitutions.As<GenericSubstitution>()->args[0].As<Type>().Ptr(); } // GenericParamIntVal @@ -1137,31 +1049,34 @@ void Type::accept(IValVisitor* visitor, void* extra) // search for a substitution that might apply to us for (auto s = subst; s; s = s->outer.Ptr()) { - // the generic decl associated with the substitution list must be - // the generic decl that declared this parameter - auto genericDecl = s->genericDecl; - if (genericDecl != declRef.getDecl()->ParentDecl) - continue; - - int index = 0; - for (auto m : genericDecl->Members) + if (auto genSubst = dynamic_cast<GenericSubstitution*>(s)) { - if (m.Ptr() == declRef.getDecl()) - { - // We've found it, so return the corresponding specialization argument - (*ioDiff)++; - return s->args[index]; - } - else if(auto typeParam = m.As<GenericTypeParamDecl>()) - { - index++; - } - else if(auto valParam = m.As<GenericValueParamDecl>()) - { - index++; - } - else + // the generic decl associated with the substitution list must be + // the generic decl that declared this parameter + auto genericDecl = genSubst->genericDecl; + if (genericDecl != declRef.getDecl()->ParentDecl) + continue; + + int index = 0; + for (auto m : genericDecl->Members) { + if (m.Ptr() == declRef.getDecl()) + { + // We've found it, so return the corresponding specialization argument + (*ioDiff)++; + return genSubst->args[index]; + } + else if (auto typeParam = m.As<GenericTypeParamDecl>()) + { + index++; + } + else if (auto valParam = m.As<GenericValueParamDecl>()) + { + index++; + } + else + { + } } } } @@ -1172,12 +1087,12 @@ void Type::accept(IValVisitor* visitor, void* extra) // Substitutions - RefPtr<Substitutions> Substitutions::SubstituteImpl(Substitutions* subst, int* ioDiff) + RefPtr<Substitutions> GenericSubstitution::SubstituteImpl(Substitutions* subst, int* ioDiff) { if (!this) return nullptr; int diff = 0; - auto outerSubst = outer->SubstituteImpl(subst, &diff); + auto outerSubst = outer ? outer->SubstituteImpl(subst, &diff) : nullptr; List<RefPtr<Val>> substArgs; for (auto a : args) @@ -1188,35 +1103,73 @@ void Type::accept(IValVisitor* visitor, void* extra) if (!diff) return this; (*ioDiff)++; - auto substSubst = new Substitutions(); + auto substSubst = new GenericSubstitution(); substSubst->genericDecl = genericDecl; substSubst->args = substArgs; return substSubst; } - bool Substitutions::Equals(Substitutions* subst) + bool GenericSubstitution::Equals(Substitutions* subst) { // both must be NULL, or non-NULL if (!this || !subst) return !this && !subst; - - if (genericDecl != subst->genericDecl) + auto genericSubst = dynamic_cast<GenericSubstitution*>(subst); + if (!genericSubst) + return false; + if (genericDecl != genericSubst->genericDecl) return false; UInt argCount = args.Count(); - SLANG_RELEASE_ASSERT(args.Count() == subst->args.Count()); + SLANG_RELEASE_ASSERT(args.Count() == genericSubst->args.Count()); for (UInt aa = 0; aa < argCount; ++aa) { - if (!args[aa]->EqualsVal(subst->args[aa].Ptr())) + if (!args[aa]->EqualsVal(genericSubst->args[aa].Ptr())) return false; } + if (!outer) + return !subst->outer; + if (!outer->Equals(subst->outer.Ptr())) return false; return true; } + RefPtr<Substitutions> ThisTypeSubstitution::SubstituteImpl(Substitutions* subst, int* ioDiff) + { + if (!this) return nullptr; + + int diff = 0; + auto outerSubst = outer->SubstituteImpl(subst, &diff); + + auto newSourceType = sourceType->SubstituteImpl(subst, ioDiff); + if (!diff) return this; + + (*ioDiff)++; + auto substSubst = new ThisTypeSubstitution(); + substSubst->sourceType = newSourceType; + substSubst->outer = outerSubst; + return substSubst; + } + + bool ThisTypeSubstitution::Equals(Substitutions* subst) + { + // both must be NULL, or non-NULL + if (!this || !subst) + return !this && !subst; + auto thisSubst = dynamic_cast<ThisTypeSubstitution*>(subst); + if (!thisSubst) + return false; + if (!thisSubst->sourceType->EqualsVal(sourceType)) + return false; + if (!outer->Equals(subst->outer.Ptr())) + return false; + return true; + } + + // DeclRefBase @@ -1248,8 +1201,6 @@ void Type::accept(IValVisitor* visitor, void* extra) return expr; SLANG_UNIMPLEMENTED_X("generic substitution into expressions"); - - return expr; } @@ -1277,7 +1228,8 @@ void Type::accept(IValVisitor* visitor, void* extra) { if (decl != declRef.decl) return false; - + if (!substitutions) + return !declRef.substitutions; if (!substitutions->Equals(declRef.substitutions.Ptr())) return false; @@ -1298,7 +1250,8 @@ void Type::accept(IValVisitor* visitor, void* extra) if (auto parentGeneric = dynamic_cast<GenericDecl*>(parentDecl)) { - if (substitutions && substitutions->genericDecl == parentDecl) + auto genSubst = substitutions.As<GenericSubstitution>(); + if (genSubst && genSubst->genericDecl == parentDecl) { // We strip away the specializations that were applied to // the parent, since we were asked for a reference *to* the parent. @@ -1427,7 +1380,6 @@ void Type::accept(IValVisitor* visitor, void* extra) else { SLANG_UNEXPECTED("unhandled syntax class name"); - return nullptr; } } @@ -1437,12 +1389,12 @@ void Type::accept(IValVisitor* visitor, void* extra) Type* HLSLPatchType::getElementType() { - return this->declRef.substitutions->args[0].As<Type>().Ptr(); + return this->declRef.substitutions.As<GenericSubstitution>()->args[0].As<Type>().Ptr(); } IntVal* HLSLPatchType::getElementCount() { - return this->declRef.substitutions->args[1].As<IntVal>().Ptr(); + return this->declRef.substitutions.As<GenericSubstitution>()->args[1].As<IntVal>().Ptr(); } // Constructors for types diff --git a/source/slang/type-defs.h b/source/slang/type-defs.h index 2052e1eb8..60f519c82 100644 --- a/source/slang/type-defs.h +++ b/source/slang/type-defs.h @@ -489,59 +489,4 @@ protected: virtual int GetHashCode() override; virtual Type* CreateCanonicalType() override; ) -END_SYNTAX_CLASS() - -// The "type" of an expression that references a asscoiated type decl (via 'assoctype' keyword). -SYNTAX_CLASS(AssocTypeDeclRefType, Type) - DECL_FIELD(DeclRef<AssocTypeDecl>, declRef) - DECL_FIELD(RefPtr<Type>, sourceType) - RAW( - AssocTypeDeclRefType() - {} - AssocTypeDeclRefType( - DeclRef<AssocTypeDecl> declRef) - : declRef(declRef) - {} - - - DeclRef<AssocTypeDecl> const& GetDeclRef() const { return declRef; } - - virtual String ToString() override; - - protected: - virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override; - virtual bool EqualsImpl(Type * type) override; - virtual int GetHashCode() override; - virtual Type* CreateCanonicalType() override; - ) -END_SYNTAX_CLASS() - -// The "type" of an generic constraint, which wraps both the sub and sup (interface) type -// the sub type can be used in associated type substitution in later type evaluation -SYNTAX_CLASS(GenericConstraintDeclRefType, Type) - DECL_FIELD(RefPtr<Type>, subType) - DECL_FIELD(RefPtr<Type>, supType) -RAW( - GenericConstraintDeclRefType() - {} - GenericConstraintDeclRefType(Session* session, - RefPtr<Type> sub, - RefPtr<Type> sup) - : subType(sub), supType(sup) - { - setSession(session); - } - - - RefPtr<Type> const& GetSupType() const { return supType; } - RefPtr<Type> const& GetSubType() const { return subType; } - - virtual String ToString() override; - - protected: - virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override; - virtual bool EqualsImpl(Type * type) override; - virtual int GetHashCode() override; - virtual Type* CreateCanonicalType() override; -) END_SYNTAX_CLASS()
\ No newline at end of file diff --git a/source/slang/vm.cpp b/source/slang/vm.cpp index f129d15e0..d5bacaa36 100644 --- a/source/slang/vm.cpp +++ b/source/slang/vm.cpp @@ -513,8 +513,8 @@ void computeTypeSizeAlign( break; default: - SLANG_UNIMPLEMENTED_X("type sizing"); impl->size = 0; + SLANG_UNIMPLEMENTED_X("type sizing"); break; } @@ -528,7 +528,7 @@ void computeTypeSizeAlign( } VMType getType( - VM* vm, + VM* /*vm*/, VMTypeImpl* typeImpl) { // TODO: need to look up an existing type that matches... @@ -587,7 +587,7 @@ VMType loadVMType( VMTypeImpl* impl = (VMTypeImpl*) alloca(size); memset(impl, 0, size); impl->op = bcType->op; - impl->argCount = argCount; + impl->argCount = (uint32_t)argCount; VMVal* args = (VMVal*) (impl + 1); for(UInt aa = 0; aa < argCount; ++aa) @@ -597,14 +597,11 @@ VMType loadVMType( return getType(vmModule->vm, impl); } - - SLANG_UNEXPECTED("unimplemented"); - return VMType(); break; } } -void* allocateImpl(VM* vm, UInt size, UInt align) +void* allocateImpl(VM* /*vm*/, UInt size, UInt /*align*/) { void* ptr = malloc(size); memset(ptr, 0, size); @@ -666,7 +663,7 @@ void* loadVMSymbol( VMModule* loadVMModuleInstance( VM* vm, void const* bytecode, - size_t bytecodeSize) + size_t /*bytecodeSize*/) { BCHeader* bcHeader = (BCHeader*) bytecode; @@ -732,14 +729,14 @@ void* findGlobalSymbolPtr( continue; if(strcmp(symbolName, name) == 0) - return getGlobalPtr(module, ss); + return getGlobalPtr(module, (uint32_t)ss); } return nullptr; } VMThread* createThread( - VM* vm) + VM* /*vm*/) { VMThread* thread = new VMThread(); thread->frame = nullptr; @@ -863,7 +860,7 @@ void resumeThread( case kIROp_BufferStore: { VMType resultType = decodeType(frame, &ip); - UInt argCount = decodeUInt(&ip); + /*UInt argCount =*/ decodeUInt(&ip); char* bufferData = decodeOperand<char*>(frame, &ip); uint32_t index = decodeOperand<uint32_t>(frame, &ip); @@ -944,10 +941,9 @@ void resumeThread( case kIROp_ReturnVal: { VMType instType = decodeType(frame, &ip); - UInt argCount = decodeUInt(&ip); + /*UInt argCount =*/ decodeUInt(&ip); void* argPtr = decodeOperandPtr<void>(frame, &ip); - VMFrame* oldFrame = frame; VMFrame* newFrame = frame->parent; vmThread->frame = newFrame; @@ -980,7 +976,7 @@ void resumeThread( Int destinationBlock = decodeSInt(&ip); for( UInt aa = 2; aa < argCount; ++aa ) { - void* argPtr = decodeOperandPtr<void>(frame, &ip); + decodeOperandPtr<void>(frame, &ip); } // TODO: we need to deal with the case of @@ -1006,7 +1002,7 @@ void resumeThread( Int falseBlockID = decodeSInt(&ip); for( UInt aa = 4; aa < argCount; ++aa ) { - void* argPtr = decodeOperandPtr<void>(frame, &ip); + decodeOperandPtr<void>(frame, &ip); } Int destinationBlock = *condition ? trueBlockID : falseBlockID; @@ -1025,7 +1021,7 @@ void resumeThread( // knowing too much about an instruction... VMType resultType = decodeType(frame, &ip); - UInt argCount = decodeUInt(&ip); + /*UInt argCount =*/ decodeUInt(&ip); void* argPtrs[16] = { 0 }; auto leftOpnd = decodeOperandPtrAndType(frame, &ip); auto type = leftOpnd.type; @@ -1050,7 +1046,7 @@ void resumeThread( case kIROp_Mul: { VMType type = decodeType(frame, &ip); - UInt argCount = decodeUInt(&ip); + /*UInt argCount =*/ decodeUInt(&ip); void* leftPtr = decodeOperandPtr<void>(frame, &ip); void* rightPtr = decodeOperandPtr<void>(frame, &ip); @@ -1072,7 +1068,7 @@ void resumeThread( case kIROp_Sub: { VMType type = decodeType(frame, &ip); - UInt argCount = decodeUInt(&ip); + /*UInt argCount =*/ decodeUInt(&ip); void* leftPtr = decodeOperandPtr<void>(frame, &ip); void* rightPtr = decodeOperandPtr<void>(frame, &ip); diff --git a/tests/compute/assoctype-complex.slang b/tests/compute/assoctype-complex.slang index de3f1a103..f29d231b6 100644 --- a/tests/compute/assoctype-complex.slang +++ b/tests/compute/assoctype-complex.slang @@ -30,30 +30,18 @@ struct Simple : ISimple return v0.sub(4, v1.sub(1,2)); } }; -/* + __generic<T:ISimple> T.U.V test(T simple, T.U v0, T.U v1) { return simple.add(v0, v1); } -__generic<T:__BuiltinArithmeticType> -T test(T v0, T v1) -{ - return v0 + v1; -} -*/ -__generic<T:__BuiltinFloatingPointType> -T test(T v0, T v1) -{ - return T(3.0); -} [numthreads(4, 1, 1)] void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { - //Simple s; - //Val v0, v1; - //float outVal = test(s, v0, v1); // == 1.0 - float outVal = test<float>(1.0, 2.0); + Simple s; + Val v0, v1; + float outVal = test(s, v0, v1); // == 1.0 outputBuffer[dispatchThreadID.x] = outVal; }
\ No newline at end of file diff --git a/tests/compute/generics-constraint1.slang b/tests/compute/generics-constraint1.slang index ff90c1cc9..aa8d398e8 100644 --- a/tests/compute/generics-constraint1.slang +++ b/tests/compute/generics-constraint1.slang @@ -6,7 +6,7 @@ RWStructuredBuffer<float> outputBuffer; __generic<T:__BuiltinFloatingPointType> T test(T v0, T v1) { - return T(3.0); + return v0; } [numthreads(4, 1, 1)] diff --git a/tests/compute/generics-constructor.slang b/tests/compute/generics-constructor.slang new file mode 100644 index 000000000..c7473cc8b --- /dev/null +++ b/tests/compute/generics-constructor.slang @@ -0,0 +1,17 @@ +//TEST(smoke, compute):COMPARE_COMPUTE:-xslang -use-ir +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out + +RWStructuredBuffer<float> outputBuffer; + +__generic<T:__BuiltinFloatingPointType> +T test(T v0, T v1) +{ + return T(3.0); +} + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + float outVal = test<float>(1.0, 2.0); + outputBuffer[dispatchThreadID.x] = outVal; +}
\ No newline at end of file |
