summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/core/core.natvis2
-rw-r--r--source/slang/bytecode.cpp56
-rw-r--r--source/slang/bytecode.h2
-rw-r--r--source/slang/check.cpp118
-rw-r--r--source/slang/decl-defs.h2
-rw-r--r--source/slang/emit.cpp11
-rw-r--r--source/slang/ir.cpp97
-rw-r--r--source/slang/lookup.cpp32
-rw-r--r--source/slang/lower-to-ir.cpp40
-rw-r--r--source/slang/lower.cpp43
-rw-r--r--source/slang/mangle.cpp6
-rw-r--r--source/slang/slang.natvis27
-rw-r--r--source/slang/syntax-base-defs.h54
-rw-r--r--source/slang/syntax.cpp340
-rw-r--r--source/slang/type-defs.h55
-rw-r--r--source/slang/vm.cpp32
-rw-r--r--tests/compute/assoctype-complex.slang20
-rw-r--r--tests/compute/generics-constraint1.slang2
-rw-r--r--tests/compute/generics-constructor.slang17
19 files changed, 434 insertions, 522 deletions
diff --git a/source/core/core.natvis b/source/core/core.natvis
index 3d9ac702e..91fdbb49b 100644
--- a/source/core/core.natvis
+++ b/source/core/core.natvis
@@ -54,7 +54,7 @@
</Expand>
</Type>
-<Type Name="Slang::RefPtrImpl&lt;*,*,*&gt;">
+<Type Name="Slang::RefPtr&lt;*&gt;">
<SmartPointer Usage="Minimal">pointer</SmartPointer>
<DisplayString Condition="pointer == 0">empty</DisplayString>
<DisplayString Condition="pointer != 0">RefPtr {*pointer}</DisplayString>
diff --git a/source/slang/bytecode.cpp b/source/slang/bytecode.cpp
index f1b23849d..ee055a01f 100644
--- a/source/slang/bytecode.cpp
+++ b/source/slang/bytecode.cpp
@@ -77,7 +77,6 @@ struct BytecodeGenerationPtr
BytecodeGenerationPtr<T> operator+(Int index) const
{
- UInt size = sizeof(T);
Int delta = index * sizeof(T);
UInt newOffset = offset + delta;
return BytecodeGenerationPtr<T>(
@@ -157,7 +156,7 @@ BCPtr<void>::RawVal allocateRaw(
for(size_t ii = currentOffset; ii < endOffset; ++ii)
context->shared->bytecode.Add(0);
- return beginOffset;
+ return (BCPtr<void>::RawVal)beginOffset;
}
template<typename T>
@@ -196,7 +195,7 @@ void encodeUInt(
{
if( value < 128 )
{
- encodeUInt8(context, value);
+ encodeUInt8(context, (uint8_t)value);
return;
}
@@ -256,9 +255,8 @@ BCConst getGlobalValue(
UInt constID = context->shared->constants.Count();
context->shared->constants.Add(value);
- BCConst bcConst;
bcConst.flavor = kBCConstFlavor_Constant;
- bcConst.id = constID;
+ bcConst.id = (uint32_t)constID;
context->shared->mapValueToGlobal.Add(value, bcConst);
@@ -270,10 +268,10 @@ BCConst getGlobalValue(
break;
}
+ bcConst.flavor = (uint32_t) -1;
+ bcConst.id = (uint32_t)-9999;
SLANG_UNEXPECTED("no ID for inst");
- bcConst.flavor = (BCConstFlavor) -1;
- bcConst.id = -9999;
- return bcConst;
+ //return bcConst;
}
Int getLocalID(
@@ -348,7 +346,6 @@ void generateBytecodeForInst(
//
auto argCount = inst->getArgCount();
- auto type = inst->getType();
encodeUInt(context, inst->op);
encodeOperand(context, inst->getType());
encodeUInt(context, argCount);
@@ -400,9 +397,9 @@ void generateBytecodeForInst(
unsigned char buffer[size];
memcpy(buffer, &ii->u.floatVal, sizeof(buffer));
- for(UInt ii = 0; ii < size; ++ii)
+ for(UInt i = 0; i < size; ++i)
{
- encodeUInt8(context, buffer[ii]);
+ encodeUInt8(context, buffer[i]);
}
// destination:
@@ -479,7 +476,7 @@ BytecodeGenerationPtr<BCType> emitBCType(
auto bcArgs = (bcType + 1).bitCast<BCPtr<uint8_t>>();
bcType->op = op;
- bcType->argCount = argCount;
+ bcType->argCount = (uint32_t)argCount;
for(UInt aa = 0; aa < argCount; ++aa)
{
@@ -489,7 +486,7 @@ BytecodeGenerationPtr<BCType> emitBCType(
UInt id = context->shared->bcTypes.Count();
context->shared->mapTypeToID.Add(type, id);
context->shared->bcTypes.Add(bcType);
- bcType->id = id;
+ bcType->id = (uint32_t)id;
return bcType;
}
@@ -577,7 +574,6 @@ BytecodeGenerationPtr<BCType> emitBCTypeImpl(
SLANG_UNEXPECTED("unimplemented");
- return BytecodeGenerationPtr<BCType>();
}
BytecodeGenerationPtr<BCType> emitBCType(
@@ -703,7 +699,7 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
// Allocate the array of block objects to be stored in the
// bytecode file.
auto bcBlocks = allocateArray<BCBlock>(context, blockCount);
- bcFunc->blockCount = blockCount;
+ bcFunc->blockCount = (uint32_t)blockCount;
bcFunc->blocks = bcBlocks;
// Now loop through the blocks again, and allocate the storage
@@ -750,7 +746,7 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
}
}
- bcBlocks[blockID].paramCount = paramCount;
+ bcBlocks[blockID].paramCount = (uint32_t)paramCount;
}
// Okay, we've counted how many registers we need for each block,
@@ -758,7 +754,7 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
UInt regCount = regCounter;
auto bcRegs = allocateArray<BCReg>(context, regCount);
- bcFunc->regCount = regCount;
+ bcFunc->regCount = (uint32_t)regCount;
bcFunc->regs = bcRegs;
// Now we will loop over things again to fill in the information
@@ -786,7 +782,7 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
#if 0
bcRegs[localID].name = tryGenerateNameForSymbol(context, pp);
#endif
- bcRegs[localID].previousVarIndexPlusOne = localID;
+ bcRegs[localID].previousVarIndexPlusOne = (uint32_t)localID;
bcRegs[localID].typeID = getTypeIDForGlobalSymbol(context, pp);
}
@@ -808,7 +804,7 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
#if 0
bcRegs[localID].name = tryGenerateNameForSymbol(context, ii);
#endif
- bcRegs[localID].previousVarIndexPlusOne = localID;
+ bcRegs[localID].previousVarIndexPlusOne = (uint32_t)localID;
bcRegs[localID].typeID = getTypeIDForGlobalSymbol(context, ii);
}
break;
@@ -828,11 +824,11 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
#if 0
bcRegs[localID].name = tryGenerateNameForSymbol(context, ii);
#endif
- bcRegs[localID].previousVarIndexPlusOne = localID;
+ bcRegs[localID].previousVarIndexPlusOne = (uint32_t)localID;
bcRegs[localID].typeID = getTypeIDForGlobalSymbol(context, ii);
bcRegs[localID+1].op = ii->op;
- bcRegs[localID+1].previousVarIndexPlusOne = localID+1;
+ bcRegs[localID+1].previousVarIndexPlusOne = (uint32_t)localID+1;
bcRegs[localID+1].typeID = getTypeID(context,
(ii->getType()->As<PtrType>())->getValueType());
}
@@ -840,7 +836,7 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
}
}
}
- assert(regCounter == regCount);
+ assert((UInt)regCounter == regCount);
// Now that we've allocated our blocks and our registers
// we can go through the actual process of emitting instructions. Hooray!
@@ -851,7 +847,7 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
List<UInt> blockOffsets;
for( auto bb = irFunc->getFirstBlock(); bb; bb = bb->getNextBlock() )
{
- UInt blockID = blockCounter++;
+ blockCounter++;
// Get local bytecode offset for current block.
UInt blockOffset = subContext->currentBytecode.Count();
@@ -903,7 +899,7 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
UInt constCount = subContext->remappedGlobalSymbols.Count();
auto bcConsts = allocateArray<BCConst>(context, constCount);
- bcFunc->constCount = constCount;
+ bcFunc->constCount = (uint32_t)constCount;
bcFunc->consts = bcConsts;
for( UInt cc = 0; cc < constCount; ++cc )
@@ -969,7 +965,7 @@ BytecodeGenerationPtr<BCModule> generateBytecodeForModule(
// Ensure that local code inside functions can see these symbols
BCConst bcConst;
bcConst.flavor = kBCConstFlavor_GlobalSymbol;
- bcConst.id = globalID;
+ bcConst.id = (uint32_t)globalID;
context->shared->mapValueToGlobal.Add(gv, bcConst);
// In the global scope, global IDs are also the local IDs
@@ -978,7 +974,7 @@ BytecodeGenerationPtr<BCModule> generateBytecodeForModule(
auto bcSymbols = allocateArray<BCPtr<BCSymbol>>(context, symbolCount);
- bcModule->symbolCount = symbolCount;
+ bcModule->symbolCount = (uint32_t)symbolCount;
bcModule->symbols = bcSymbols;
for( auto gv = irModule->getFirstGlobalValue(); gv; gv = gv->getNextValue() )
@@ -998,7 +994,7 @@ BytecodeGenerationPtr<BCModule> generateBytecodeForModule(
// At this point we should have identified all the literals we need:
UInt constantCount = context->shared->constants.Count();
auto bcConstants = allocateArray<BCConstant>(context, constantCount);
- bcModule->constantCount = constantCount;
+ bcModule->constantCount = (uint32_t)constantCount;
bcModule->constants = bcConstants;
for(UInt cc = 0; cc < constantCount; ++cc)
@@ -1026,7 +1022,7 @@ BytecodeGenerationPtr<BCModule> generateBytecodeForModule(
// At this point we should have collected all the types we need:
UInt typeCount = context->shared->bcTypes.Count();
auto bcTypes = allocateArray<BCPtr<BCType>>(context, typeCount);
- bcModule->typeCount = typeCount;
+ bcModule->typeCount = (uint32_t)typeCount;
bcModule->types = bcTypes;
for(UInt tt = 0; tt < typeCount; ++tt)
@@ -1055,8 +1051,6 @@ void generateBytecodeContainer(
// TODO: Need to dump BC representation of compiled kernel codes
// for each specified code-generation target.
- UInt translationUnitCount = compileReq->translationUnits.Count();
-
List<BytecodeGenerationPtr<BCModule>> bcModulesList;
for (auto translationUnitReq : compileReq->translationUnits)
{
@@ -1065,7 +1059,7 @@ void generateBytecodeContainer(
}
UInt bcModuleCount = bcModulesList.Count();
- header->moduleCount = bcModuleCount;
+ header->moduleCount = (uint32_t)bcModuleCount;
auto bcModules = allocateArray<BCPtr<BCModule>>(context, bcModuleCount);
header->modules = bcModules;
diff --git a/source/slang/bytecode.h b/source/slang/bytecode.h
index 75b9f15cd..f1ad52c32 100644
--- a/source/slang/bytecode.h
+++ b/source/slang/bytecode.h
@@ -56,7 +56,7 @@ struct BCPtr
{
if (ptr)
{
- rawVal = (char*)ptr - (char*)this;
+ rawVal = (RawVal)((char*)ptr - (char*)this);
}
else
{
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index dc4e48545..137b1c451 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -148,15 +148,13 @@ namespace Slang
if (baseExpr)
{
RefPtr<Expr> expr;
-
+
if (baseExpr->type->As<TypeType>())
{
auto sexpr = new StaticMemberExpr();
sexpr->loc = loc;
sexpr->BaseExpression = baseExpr;
sexpr->name = declRef.GetName();
- sexpr->type = GetTypeForDeclRef(declRef);
- sexpr->declRef = declRef;
expr = sexpr;
}
else
@@ -165,54 +163,24 @@ namespace Slang
sexpr->loc = loc;
sexpr->BaseExpression = baseExpr;
sexpr->name = declRef.GetName();
- sexpr->type = GetTypeForDeclRef(declRef);
sexpr->declRef = declRef;
expr = sexpr;
}
- if (auto constraintType = expr->type->As<GenericConstraintDeclRefType>())
+ if (auto assocTypeDecl = declRef.As<AssocTypeDecl>())
{
- if (baseExpr->type->As<TypeType>())
- constraintType->subType = baseExpr->type->As<TypeType>()->type;
- else
- constraintType->subType = baseExpr->type;
-
+ RefPtr<ThisTypeSubstitution> subst = new ThisTypeSubstitution();
+ subst->sourceType = baseExpr->type.type;
+ if (auto typeType = subst->sourceType.As<TypeType>())
+ subst->sourceType = typeType->type;
+ expr->type = GetTypeForDeclRef(DeclRef<AssocTypeDecl>(assocTypeDecl.getDecl(), subst));
}
-
- if (auto genConstraintType = baseExpr->type->As<GenericConstraintDeclRefType>())
+ else if (auto constraintDecl = declRef.As<GenericTypeConstraintDecl>())
{
- if (auto funcDeclRef = declRef.As<CallableDecl>())
- {
- // if this is call expression, propagate the source associated type to the result type
- auto funcType = expr->type->As<FuncType>();
- if (auto assocRsType = funcType->resultType.As<AssocTypeDeclRefType>())
- {
- RefPtr<FuncType> newFuncType = new FuncType();
- newFuncType->paramTypes = funcType->paramTypes;
- RefPtr<AssocTypeDeclRefType> newRsType = new AssocTypeDeclRefType();
- newRsType->declRef = assocRsType->declRef;
- newRsType->sourceType = genConstraintType->subType;
- newRsType->setSession(getSession());
- newFuncType->resultType = newRsType;
- newFuncType->setSession(funcType->getSession());
- expr->type = QualType(newFuncType);
- }
- }
- else if (auto assocTypeDeclRef = declRef.As<AssocTypeDecl>())
- {
- auto assocTypeDeclType = new AssocTypeDeclRefType();
- assocTypeDeclType->declRef = assocTypeDeclRef;
- assocTypeDeclType->sourceType = genConstraintType->subType;
- assocTypeDeclType->setSession(getSession());
- expr->type = QualType(getTypeType(assocTypeDeclType));
- }
+ expr->type = baseExpr->type;
}
- else if (auto assocTypeDeclRef = declRef.As<AssocTypeDecl>())
+ else
{
- auto assocTypeDeclType = new AssocTypeDeclRefType();
- assocTypeDeclType->declRef = assocTypeDeclRef;
- assocTypeDeclType->sourceType = baseExpr->type;
- assocTypeDeclType->setSession(getSession());
- expr->type = QualType(getTypeType(assocTypeDeclType));
+ expr->type = GetTypeForDeclRef(declRef);
}
return expr;
}
@@ -449,7 +417,7 @@ namespace Slang
DeclRef<GenericDecl> genericDeclRef,
List<RefPtr<Expr>> const& args)
{
- RefPtr<Substitutions> subst = new Substitutions();
+ RefPtr<GenericSubstitution> subst = new GenericSubstitution();
subst->genericDecl = genericDeclRef.getDecl();
subst->outer = genericDeclRef.substitutions;
@@ -2097,10 +2065,10 @@ namespace Slang
return true;
}
- RefPtr<Substitutions> createDummySubstitutions(
+ RefPtr<GenericSubstitution> createDummySubstitutions(
GenericDecl* genericDecl)
{
- RefPtr<Substitutions> subst = new Substitutions();
+ RefPtr<GenericSubstitution> subst = new GenericSubstitution();
subst->genericDecl = genericDecl;
for (auto dd : genericDecl->Members)
{
@@ -3115,7 +3083,7 @@ namespace Slang
session, "Vector").As<GenericDecl>();
auto vectorTypeDecl = vectorGenericDecl->inner;
- auto substitutions = new Substitutions();
+ auto substitutions = new GenericSubstitution();
substitutions->genericDecl = vectorGenericDecl.Ptr();
substitutions->args.Add(elementType);
substitutions->args.Add(elementCount);
@@ -3447,7 +3415,7 @@ namespace Slang
// to `interfaceDeclRef`.
//
SLANG_UNEXPECTED("reflexive type witness");
- return nullptr;
+ //return nullptr;
}
auto breadcrumbs = inBreadcrumbs;
@@ -3462,7 +3430,7 @@ namespace Slang
// because `A : B` and `B : C` then `A : C`
//
SLANG_UNEXPECTED("transitive type witness");
- return nullptr;
+ //return nullptr;
}
// Simple case: we have a single declaration
@@ -3815,7 +3783,7 @@ namespace Slang
// Consruct a reference to the extension with our constraint variables
// as the
- RefPtr<Substitutions> solvedSubst = new Substitutions();
+ RefPtr<GenericSubstitution> solvedSubst = new GenericSubstitution();
solvedSubst->genericDecl = genericDeclRef.getDecl();
solvedSubst->outer = genericDeclRef.substitutions;
solvedSubst->args = args;
@@ -4084,8 +4052,9 @@ namespace Slang
// We will go ahead and hang onto the arguments that we've
// already checked, since downstream validation might need
// them.
- candidate.subst = new Substitutions();
- auto& checkedArgs = candidate.subst->args;
+ auto genSubst = new GenericSubstitution();
+ candidate.subst = genSubst;
+ auto& checkedArgs = genSubst->args;
int aa = 0;
for (auto memberRef : getMembers(genericDeclRef))
@@ -4202,7 +4171,7 @@ namespace Slang
// Create a witness that attests to the fact that `type`
// is equal to itself.
RefPtr<Val> createTypeEqualityWitness(
- Type* type)
+ Type* /*type*/)
{
SLANG_UNEXPECTED("unimplemented");
}
@@ -4258,7 +4227,7 @@ namespace Slang
// We should have the existing arguments to the generic
// handy, so that we can construct a substitution list.
- RefPtr<Substitutions> subst = candidate.subst;
+ RefPtr<GenericSubstitution> subst = candidate.subst.As<GenericSubstitution>();
assert(subst);
subst->genericDecl = genericDeclRef.getDecl();
@@ -4325,7 +4294,7 @@ namespace Slang
RefPtr<Expr> createGenericDeclRef(
RefPtr<Expr> baseExpr,
RefPtr<Expr> originalExpr,
- RefPtr<Substitutions> subst)
+ RefPtr<GenericSubstitution> subst)
{
auto baseDeclRefExpr = baseExpr.As<DeclRefExpr>();
if (!baseDeclRefExpr)
@@ -4437,7 +4406,7 @@ namespace Slang
return createGenericDeclRef(
baseExpr,
context.originalExpr,
- candidate.subst);
+ candidate.subst.As<GenericSubstitution>());
break;
default:
@@ -4734,22 +4703,23 @@ namespace Slang
// They must both be NULL or non-NULL
if (!fst || !snd)
return fst == snd;
-
+ auto fstGen = fst.As<GenericSubstitution>();
+ auto sndGen = snd.As<GenericSubstitution>();
// They must be specializing the same generic
- if (fst->genericDecl != snd->genericDecl)
+ if (fstGen->genericDecl != sndGen->genericDecl)
return false;
// Their arguments must unify
- SLANG_RELEASE_ASSERT(fst->args.Count() == snd->args.Count());
- UInt argCount = fst->args.Count();
+ SLANG_RELEASE_ASSERT(fstGen->args.Count() == sndGen->args.Count());
+ UInt argCount = fstGen->args.Count();
for (UInt aa = 0; aa < argCount; ++aa)
{
- if (!TryUnifyVals(constraints, fst->args[aa], snd->args[aa]))
+ if (!TryUnifyVals(constraints, fstGen->args[aa], sndGen->args[aa]))
return false;
}
// Their "base" specializations must unify
- if (!TryUnifySubstitutions(constraints, fst->outer, snd->outer))
+ if (!TryUnifySubstitutions(constraints, fstGen->outer, sndGen->outer))
return false;
return true;
@@ -5304,11 +5274,12 @@ namespace Slang
if( parentGenericDeclRef )
{
SLANG_RELEASE_ASSERT(declRef.substitutions);
- SLANG_RELEASE_ASSERT(declRef.substitutions->genericDecl == parentGenericDeclRef.getDecl());
+ auto genSubst = declRef.substitutions.As<GenericSubstitution>();
+ SLANG_RELEASE_ASSERT(genSubst->genericDecl == parentGenericDeclRef.getDecl());
sb << "<";
bool first = true;
- for(auto arg : declRef.substitutions->args)
+ for(auto arg : genSubst->args)
{
if(!first) sb << ", ";
formatVal(sb, arg);
@@ -6069,7 +6040,7 @@ namespace Slang
RefPtr<Expr> visitStaticMemberExpr(StaticMemberExpr* expr)
{
SLANG_UNEXPECTED("should not occur in unchecked AST");
- return expr;
+ //return expr;
}
RefPtr<Expr> lookupResultFailure(
@@ -6495,26 +6466,11 @@ namespace Slang
*outTypeResult = type;
return QualType(getTypeType(type));
}
- else if (auto constraintDeclRef = declRef.As<GenericTypeConstraintDecl>())
- {
- // When we access a constraint or an inheritance decl (as a member),
- // we are conceptually performing a "cast" to the given super-type,
- // with the declaration showing that such a cast is legal.
- auto type = new GenericConstraintDeclRefType(session, GetSub(constraintDeclRef), GetSup(constraintDeclRef));
- return QualType(type);
- }
else if (auto funcDeclRef = declRef.As<CallableDecl>())
{
auto type = getFuncType(session, funcDeclRef);
return QualType(type);
}
- else if (auto assocTypeDeclRef = declRef.As<AssocTypeDecl>())
- {
- auto type = new AssocTypeDeclRefType(assocTypeDeclRef);
- type->setSession(session);
- *outTypeResult = type;
- return QualType(getTypeType(type));
- }
if( sink )
{
sink->diagnose(declRef, Diagnostics::unimplemented, "cannot form reference to this kind of declaration");
@@ -6558,7 +6514,7 @@ namespace Slang
if(decl != genericDecl->inner)
return parentSubst;
- RefPtr<Substitutions> subst = new Substitutions();
+ RefPtr<GenericSubstitution> subst = new GenericSubstitution();
subst->genericDecl = genericDecl;
subst->outer = parentSubst;
diff --git a/source/slang/decl-defs.h b/source/slang/decl-defs.h
index bb8c26f58..9c010d156 100644
--- a/source/slang/decl-defs.h
+++ b/source/slang/decl-defs.h
@@ -123,7 +123,7 @@ SYNTAX_CLASS(TypeDefDecl, SimpleTypeDecl)
END_SYNTAX_CLASS()
// An 'assoctype' declaration, it is a container of inheritance clauses
-SYNTAX_CLASS(AssocTypeDecl, ContainerDecl)
+SYNTAX_CLASS(AssocTypeDecl, AggTypeDecl)
END_SYNTAX_CLASS()
// A scope for local declarations (e.g., as part of a statement)
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp
index 9caad2e1a..2627e1f37 100644
--- a/source/slang/emit.cpp
+++ b/source/slang/emit.cpp
@@ -1052,15 +1052,6 @@ struct EmitVisitor
EmitDeclarator(arg.declarator);
}
- void visitAssocTypeDeclRefType(AssocTypeDeclRefType* /*type*/, TypeEmitArg const& /*arg*/)
- {
- //SLANG_UNREACHABLE("visitAssocTypeDeclRefType in EmitVisitor");
- }
- void visitGenericConstraintDeclRefType(GenericConstraintDeclRefType* /*type*/, TypeEmitArg const& /*arg*/)
- {
- //SLANG_UNREACHABLE("visitGenericConstraintDeclRefType in EmitVisitor");
- }
-
void visitBasicExpressionType(BasicExpressionType* basicType, TypeEmitArg const& arg)
{
auto declarator = arg.declarator;
@@ -2925,7 +2916,7 @@ struct EmitVisitor
return;
}
- Substitutions* subst = declRef.substitutions.Ptr();
+ GenericSubstitution* subst = declRef.substitutions.As<GenericSubstitution>().Ptr();
if (!subst)
return;
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp
index 84ab3d9a5..5f3ca6c62 100644
--- a/source/slang/ir.cpp
+++ b/source/slang/ir.cpp
@@ -178,27 +178,27 @@ namespace Slang
//
// Add an instruction to a specific parent
- void IRBuilder::addInst(IRBlock* block, IRInst* inst)
+ void IRBuilder::addInst(IRBlock* pblock, IRInst* inst)
{
- inst->parent = block;
+ inst->parent = pblock;
- if (!block->firstInst)
+ if (!pblock->firstInst)
{
inst->prev = nullptr;
inst->next = nullptr;
- block->firstInst = inst;
- block->lastInst = inst;
+ pblock->firstInst = inst;
+ pblock->lastInst = inst;
}
else
{
- auto prev = block->lastInst;
+ auto prev = pblock->lastInst;
inst->prev = prev;
inst->next = nullptr;
prev->next = inst;
- block->lastInst = inst;
+ pblock->lastInst = inst;
}
}
@@ -206,7 +206,6 @@ namespace Slang
void IRBuilder::addInst(
IRInst* inst)
{
- auto insertBefore = insertBeforeInst;
if(insertBeforeInst)
{
inst->insertBefore(insertBeforeInst);
@@ -221,7 +220,7 @@ namespace Slang
}
static IRValue* createValueImpl(
- IRBuilder* builder,
+ IRBuilder* /*builder*/,
UInt size,
IROp op,
IRType* type)
@@ -249,7 +248,7 @@ namespace Slang
// arguments *after* the type (which is a mandatory
// argument for all instructions).
static IRInst* createInstImpl(
- IRBuilder* builder,
+ IRBuilder* /*builder*/,
UInt size,
IROp op,
IRType* type,
@@ -261,8 +260,7 @@ namespace Slang
IRInst* inst = (IRInst*) malloc(size);
memset(inst, 0, size);
- auto module = builder->getModule();
- inst->argCount = fixedArgCount + varArgCount;
+ inst->argCount = (uint32_t)(fixedArgCount + varArgCount);
inst->op = op;
@@ -632,7 +630,7 @@ namespace Slang
IRInst* IRBuilder::emitCallInst(
IRType* type,
- IRValue* func,
+ IRValue* pFunc,
UInt argCount,
IRValue* const* args)
{
@@ -641,7 +639,7 @@ namespace Slang
kIROp_Call,
type,
1,
- &func,
+ &pFunc,
argCount,
args);
addInst(inst);
@@ -829,12 +827,12 @@ namespace Slang
IRFunc* IRBuilder::createFunc()
{
- IRFunc* func = createValue<IRFunc>(
+ IRFunc* rsFunc = createValue<IRFunc>(
this,
kIROp_Func,
nullptr);
- addGlobalValue(getModule(), func);
- return func;
+ addGlobalValue(getModule(), rsFunc);
+ return rsFunc;
}
IRGlobalVar* IRBuilder::createGlobalVar(
@@ -1129,13 +1127,13 @@ namespace Slang
}
IRInst* IRBuilder::emitBranch(
- IRBlock* block)
+ IRBlock* pBlock)
{
auto inst = createInst<IRUnconditionalBranch>(
this,
kIROp_unconditionalBranch,
nullptr,
- block);
+ pBlock);
addInst(inst);
return inst;
}
@@ -1543,7 +1541,7 @@ namespace Slang
if(genericParentDeclRef)
{
- auto subst = declRef.substitutions;
+ auto subst = declRef.substitutions.As<GenericSubstitution>();
if( !subst || subst->genericDecl != genericParentDeclRef.getDecl() )
{
// No actual substitutions in place here
@@ -1698,6 +1696,7 @@ namespace Slang
dumpChildrenRaw(context, block);
}
+#if 0
static void dumpChildrenRaw(
IRDumpContext* context,
IRFunc* func)
@@ -1720,6 +1719,7 @@ namespace Slang
dumpIndent(context);
dump(context, "}\n");
}
+#endif
static void dumpInst(
IRDumpContext* context,
@@ -2239,8 +2239,8 @@ namespace Slang
void IRInst::removeArguments()
{
- UInt argCount = this->argCount;
- for( UInt aa = 0; aa < argCount; ++aa )
+ UInt oldArgCount = this->argCount;
+ for( UInt aa = 0; aa < oldArgCount; ++aa )
{
IRUse& use = getArgs()[aa];
@@ -2384,9 +2384,9 @@ namespace Slang
IRBuilder* builder,
Type* type,
VarLayout* varLayout,
- TypeLayout* typeLayout,
- LayoutResourceKind kind,
- GlobalVaryingDeclarator* declarator)
+ TypeLayout* /*typeLayout*/,
+ LayoutResourceKind /*kind*/,
+ GlobalVaryingDeclarator* /*declarator*/)
{
// TODO: We might be creating an `in` or `out` variable based on
// an `in out` function parameter. In this case we should
@@ -2550,7 +2550,6 @@ namespace Slang
default:
SLANG_UNEXPECTED("unimplemented");
- return ScalarizedVal();
}
}
@@ -3079,7 +3078,6 @@ namespace Slang
break;
default:
SLANG_UNEXPECTED("no value registered for IR value");
- return nullptr;
}
}
@@ -3111,18 +3109,27 @@ namespace Slang
{
if (!subst)
return nullptr;
+ if (auto genSubst = dynamic_cast<GenericSubstitution*>(subst))
+ {
+ RefPtr<GenericSubstitution> newSubst = new GenericSubstitution();
+ newSubst->outer = cloneSubstitutions(context, subst->outer);
+ newSubst->genericDecl = genSubst->genericDecl;
- RefPtr<Substitutions> newSubst = new Substitutions();
- newSubst->outer = cloneSubstitutions(context, subst->outer);
- newSubst->genericDecl = subst->genericDecl;
-
- for (auto arg : subst->args)
+ for (auto arg : genSubst->args)
+ {
+ auto newArg = cloneSubstitutionArg(context, arg);
+ newSubst->args.Add(arg);
+ }
+ return newSubst;
+ }
+ else if (auto thisSubst = dynamic_cast<ThisTypeSubstitution*>(subst))
{
- auto newArg = cloneSubstitutionArg(context, arg);
- newSubst->args.Add(arg);
+ RefPtr<ThisTypeSubstitution> newSubst = new ThisTypeSubstitution();
+ newSubst->sourceType = thisSubst->sourceType;
+ newSubst->outer = cloneSubstitutions(context, subst->outer);
+ return newSubst;
}
-
- return newSubst;
+ return nullptr;
}
DeclRef<Decl> IRSpecContext::maybeCloneDeclRef(DeclRef<Decl> const& declRef)
@@ -3231,7 +3238,7 @@ namespace Slang
{
auto clonedKey = context->maybeCloneValue(originalEntry->requirementKey.usedValue);
auto clonedVal = context->maybeCloneValue(originalEntry->satisfyingVal.usedValue);
- auto clonedEntry = context->builder->createWitnessTableEntry(
+ context->builder->createWitnessTableEntry(
clonedTable,
clonedKey,
clonedVal);
@@ -3416,7 +3423,6 @@ namespace Slang
default:
SLANG_UNEXPECTED("unhandled case");
- return "unknown";
}
}
@@ -3518,7 +3524,6 @@ namespace Slang
{
// This shouldn't happen!
SLANG_UNEXPECTED("no matching function registered");
- return cloneSimpleFunc(context, originalFunc);
}
// We will try to track the "best" definition we can find.
@@ -3748,7 +3753,6 @@ namespace Slang
else
{
SLANG_UNEXPECTED("unimplemented");
- return nullptr;
}
}
@@ -3756,7 +3760,8 @@ namespace Slang
IRGenericSpecContext* context,
DeclRef<Decl> declRef)
{
- auto subst = context->subst;
+ auto subst = context->subst.As<GenericSubstitution>();
+ SLANG_ASSERT(subst);
auto genericDecl = subst->genericDecl;
UInt orinaryParamCount = 0;
@@ -3806,12 +3811,13 @@ namespace Slang
{
auto declRefVal = (IRDeclRef*) originalVal;
auto declRef = declRefVal->declRef;
-
+ auto genSubst = subst.As<GenericSubstitution>();
+ SLANG_ASSERT(genSubst);
// We may have a direct reference to one of the parameters
// of the generic we are specializing, and in that case
// we nee to translate it over to the equiavalent of
// the `Val` we have been given.
- if(declRef.getDecl()->ParentDecl == subst->genericDecl)
+ if(declRef.getDecl()->ParentDecl == genSubst->genericDecl)
{
return getSubstValue(this, declRef);
}
@@ -3866,9 +3872,10 @@ namespace Slang
// using a different overload of a target-specific function,
// so we need to create a dummy substitution here, to make
// sure it used the correct generic.
- RefPtr<Substitutions> newSubst = new Substitutions();
+ RefPtr<GenericSubstitution> newSubst = new GenericSubstitution();
newSubst->genericDecl = genericFunc->genericDecl;
- newSubst->args = specDeclRef.substitutions->args;
+ auto specDeclRefSubst = specDeclRef.substitutions.As<GenericSubstitution>();
+ newSubst->args = specDeclRefSubst->args;
IRGenericSpecContext context;
context.shared = sharedContext;
diff --git a/source/slang/lookup.cpp b/source/slang/lookup.cpp
index b97dac560..a5425074f 100644
--- a/source/slang/lookup.cpp
+++ b/source/slang/lookup.cpp
@@ -452,28 +452,28 @@ void lookUpMemberImpl(
lookUpMemberImpl(session, semantics, name, bound, ioResult, &breadcrumb);
}
}
- }
- else if (auto assocTypeDeclRefType = type->As<AssocTypeDeclRefType>())
- {
- auto assocTypeDeclRef = assocTypeDeclRefType->declRef;
- for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(assocTypeDeclRef))
+ else if (auto assocTypeDeclRef = declRef.As<AssocTypeDecl>())
{
- // The super-type in the constraint (e.g., `Foo` in `T : Foo`)
- // will tell us a type we should use for lookup.
- auto bound = GetSup(constraintDeclRef);
+ for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(assocTypeDeclRef))
+ {
+ // The super-type in the constraint (e.g., `Foo` in `T : Foo`)
+ // will tell us a type we should use for lookup.
+ auto bound = GetSup(constraintDeclRef);
- // Go ahead and use the target type, with an appropriate breadcrumb
- // to indicate that we indirected through a type constraint.
+ // Go ahead and use the target type, with an appropriate breadcrumb
+ // to indicate that we indirected through a type constraint.
- BreadcrumbInfo breadcrumb;
- breadcrumb.prev = inBreadcrumbs;
- breadcrumb.kind = LookupResultItem::Breadcrumb::Kind::Constraint;
- breadcrumb.declRef = constraintDeclRef;
+ BreadcrumbInfo breadcrumb;
+ breadcrumb.prev = inBreadcrumbs;
+ breadcrumb.kind = LookupResultItem::Breadcrumb::Kind::Constraint;
+ breadcrumb.declRef = constraintDeclRef;
- // TODO: Need to consider case where this might recurse infinitely.
- lookUpMemberImpl(session, semantics, name, bound, ioResult, &breadcrumb);
+ // TODO: Need to consider case where this might recurse infinitely.
+ lookUpMemberImpl(session, semantics, name, bound, ioResult, &breadcrumb);
+ }
}
}
+
}
LookupResult lookUpMember(
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index 998197279..6099df1ed 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -870,9 +870,12 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
auto subs = declRef.substitutions;
while(subs)
{
- for(auto aa : subs->args)
+ if (auto genSubst = subs.As<GenericSubstitution>())
{
- (*ioArgs).Add(getSimpleVal(context, lowerVal(context, aa)));
+ for (auto aa : genSubst->args)
+ {
+ (*ioArgs).Add(getSimpleVal(context, lowerVal(context, aa)));
+ }
}
subs = subs->outer;
}
@@ -3037,24 +3040,33 @@ RefPtr<Substitutions> lowerSubstitutions(
{
if(!subst)
return nullptr;
+ RefPtr<Substitutions> result;
+ if (auto genSubst = dynamic_cast<GenericSubstitution*>(subst))
+ {
+ RefPtr<GenericSubstitution> newSubst = new GenericSubstitution();
+ newSubst->genericDecl = genSubst->genericDecl;
+
+ for (auto arg : genSubst->args)
+ {
+ auto newArg = lowerSubstitutionArg(context, arg);
+ newSubst->args.Add(newArg);
+ }
- RefPtr<Substitutions> newSubst = new Substitutions();
+ result = newSubst;
+ }
+ else if (auto thisSubst = dynamic_cast<ThisTypeSubstitution*>(subst))
+ {
+ RefPtr<ThisTypeSubstitution> newSubst = new ThisTypeSubstitution();
+ newSubst->sourceType = lowerSubstitutionArg(context, thisSubst->sourceType);
+ result = newSubst;
+ }
if (subst->outer)
{
- newSubst->outer = lowerSubstitutions(
+ result->outer = lowerSubstitutions(
context,
subst->outer);
}
-
- newSubst->genericDecl = subst->genericDecl;
-
- for (auto arg : subst->args)
- {
- auto newArg = lowerSubstitutionArg(context, arg);
- newSubst->args.Add(newArg);
- }
-
- return newSubst;
+ return result;
}
LoweredValInfo emitDeclRef(
diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp
index b810d9643..b8696df47 100644
--- a/source/slang/lower.cpp
+++ b/source/slang/lower.cpp
@@ -779,20 +779,6 @@ struct LoweringVisitor
translateDeclRef(DeclRef<Decl>(type->declRef)).As<TypeDefDecl>());
}
- RefPtr<Type> visitGenericConstraintDeclRefType(GenericConstraintDeclRefType* type)
- {
- // not supported by lowering
- SLANG_UNREACHABLE("visitGenericConstraintDeclRefType in LowerVisitor");
- return nullptr;
- }
-
- RefPtr<Type> visitAssocTypeDeclRefType(AssocTypeDeclRefType* type)
- {
- // not supported by lowering
- SLANG_UNREACHABLE("visitAssocTypeDeclRefType in LowerVisitor");
- return nullptr;
- }
-
RefPtr<Type> visitTypeType(TypeType* type)
{
return getTypeType(lowerType(type->type));
@@ -2569,14 +2555,23 @@ struct LoweringVisitor
Substitutions* inSubstitutions)
{
if (!inSubstitutions) return nullptr;
-
- RefPtr<Substitutions> result = new Substitutions();
- result->genericDecl = translateDeclRef(inSubstitutions->genericDecl).As<GenericDecl>();
- for (auto arg : inSubstitutions->args)
+ if (auto genSubst = dynamic_cast<GenericSubstitution*>(inSubstitutions))
{
- result->args.Add(translateVal(arg));
+ RefPtr<GenericSubstitution> result = new GenericSubstitution();
+ result->genericDecl = translateDeclRef(genSubst->genericDecl).As<GenericDecl>();
+ for (auto arg : genSubst->args)
+ {
+ result->args.Add(translateVal(arg));
+ }
+ return result;
}
- return result;
+ else if (auto thisSubst = dynamic_cast<ThisTypeSubstitution*>(inSubstitutions))
+ {
+ RefPtr<ThisTypeSubstitution> result = new ThisTypeSubstitution();
+ result->sourceType = translateVal(result->sourceType);
+ return result;
+ }
+ return nullptr;
}
static Decl* getModifiedDecl(Decl* decl)
@@ -2733,7 +2728,11 @@ struct LoweringVisitor
RefPtr<VarLayout> tryToFindLayout(
Decl* decl)
{
- auto loweredParent = translateDeclRef(decl->ParentDecl);
+ RefPtr<Decl> loweredParent;
+ if (auto genericParentDecl = decl->ParentDecl->As<GenericDecl>())
+ loweredParent = translateDeclRef(genericParentDecl->ParentDecl);
+ else
+ loweredParent = translateDeclRef(decl->ParentDecl);
if (loweredParent)
{
auto layoutMod = loweredParent->FindModifier<ComputedLayoutModifier>();
@@ -3831,7 +3830,7 @@ struct LoweringVisitor
"Vector").As<GenericDecl>();
auto vectorTypeDecl = vectorGenericDecl->inner;
- auto substitutions = new Substitutions();
+ auto substitutions = new GenericSubstitution();
substitutions->genericDecl = vectorGenericDecl.Ptr();
substitutions->args.Add(elementType);
substitutions->args.Add(elementCount);
diff --git a/source/slang/mangle.cpp b/source/slang/mangle.cpp
index cd90a0e41..bde4e12c7 100644
--- a/source/slang/mangle.cpp
+++ b/source/slang/mangle.cpp
@@ -117,10 +117,6 @@ namespace Slang
{
emitQualifiedName(context, declRefType->declRef);
}
- else if (auto assocTypeDeclRefType = dynamic_cast<AssocTypeDeclRefType*>(type))
- {
- emitQualifiedName(context, assocTypeDeclRefType->declRef);
- }
else
{
SLANG_UNEXPECTED("unimplemented case in mangling");
@@ -199,7 +195,7 @@ namespace Slang
// There are two cases here: either we have specializations
// in place for the parent generic declaration, or we don't.
- auto subst = declRef.substitutions;
+ auto subst = declRef.substitutions.As<GenericSubstitution>();
if( subst && subst->genericDecl == parentGenericDeclRef.getDecl() )
{
// This is the case where we *do* have substitutions.
diff --git a/source/slang/slang.natvis b/source/slang/slang.natvis
index 7a7f7fe0e..7e6fd3753 100644
--- a/source/slang/slang.natvis
+++ b/source/slang/slang.natvis
@@ -9,5 +9,30 @@
<ExpandedItem>rawVal ? ($T1*)((char*)this + rawVal) : ($T1*)0</ExpandedItem>
</Expand>
</Type>
-
+ <Type Name="Slang::DeclRef&lt;*&gt;">
+ <SmartPointer Usage="Minimal">decl ? ($T1*)(decl) : ($T1*)0</SmartPointer>
+ <DisplayString Condition="decl == 0">DeclRef nullptr</DisplayString>
+ <DisplayString Condition="decl != 0">DeclRef {(*(*(Slang::DeclRefBase*)this).decl).nameAndLoc}</DisplayString>
+ <Expand>
+ <ExpandedItem>decl ? ($T1*)(decl) : ($T1*)0</ExpandedItem>
+ <Item Name="[Substitutions]:">"========================="</Item>
+ <LinkedListItems>
+ <HeadPointer>substitutions.pointer</HeadPointer>
+ <NextPointer>outer.pointer</NextPointer>
+ <ValueNode>this</ValueNode>
+ </LinkedListItems>
+ </Expand>
+ </Type>
+ <Type Name="Slang::DeclRefType">
+ <DisplayString>DeclRefType {declRef}</DisplayString>
+ <Expand>
+ <ExpandedItem>declRef</ExpandedItem>
+ </Expand>
+ </Type>
+ <Type Name="Slang::Name">
+ <DisplayString>{{name={(char*)(text.buffer.pointer+1), s}}}</DisplayString>
+ </Type>
+ <Type Name="Slang::NameLoc">
+ <DisplayString>{{name={(char*)((*name).text.buffer.pointer+1), s} loc={loc.raw}}}</DisplayString>
+ </Type>
</AutoVisualizer> \ No newline at end of file
diff --git a/source/slang/syntax-base-defs.h b/source/slang/syntax-base-defs.h
index 8a22a61d2..3c7e8c5ae 100644
--- a/source/slang/syntax-base-defs.h
+++ b/source/slang/syntax-base-defs.h
@@ -126,31 +126,44 @@ protected:
)
END_SYNTAX_CLASS()
+
// A substitution represents a binding of certain
// type-level variables to concrete argument values
-SYNTAX_CLASS(Substitutions, RefObject)
+ABSTRACT_SYNTAX_CLASS(Substitutions, RefObject)
+
+ // Any further substitutions, relating to outer generic declarations
+ SYNTAX_FIELD(RefPtr<Substitutions>, outer)
+
+ RAW(
+ // Apply a set of substitutions to the bindings in this substitution
+ virtual RefPtr<Substitutions> SubstituteImpl(Substitutions* subst, int* ioDiff) = 0;
+ // Check if these are equivalent substitutiosn to another set
+ virtual bool Equals(Substitutions* subst) = 0;
+ virtual bool operator == (const Substitutions & subst) = 0;
+ virtual int GetHashCode() const = 0;
+ )
+END_SYNTAX_CLASS()
+
+SYNTAX_CLASS(GenericSubstitution, Substitutions)
// The generic declaration that defines the
// parametesr we are binding to arguments
- DECL_FIELD(GenericDecl*, genericDecl)
+ DECL_FIELD(GenericDecl*, genericDecl)
// The actual values of the arguments
SYNTAX_FIELD(List<RefPtr<Val>>, args)
-
- // Any further substitutions, relating to outer generic declarations
- SYNTAX_FIELD(RefPtr<Substitutions>, outer)
-
+
RAW(
// Apply a set of substitutions to the bindings in this substitution
- RefPtr<Substitutions> SubstituteImpl(Substitutions* subst, int* ioDiff);
+ virtual RefPtr<Substitutions> SubstituteImpl(Substitutions* subst, int* ioDiff) override;
// Check if these are equivalent substitutiosn to another set
- bool Equals(Substitutions* subst);
- bool operator == (const Substitutions & subst)
+ virtual bool Equals(Substitutions* subst) override;
+ virtual bool operator == (const Substitutions & subst) override
{
return Equals(const_cast<Substitutions*>(&subst));
}
- int GetHashCode() const
+ virtual int GetHashCode() const override
{
int rs = 0;
for (auto && v : args)
@@ -163,6 +176,27 @@ SYNTAX_CLASS(Substitutions, RefObject)
)
END_SYNTAX_CLASS()
+SYNTAX_CLASS(ThisTypeSubstitution, Substitutions)
+ // The actual type that provides the lookup scope for an associated type
+ SYNTAX_FIELD(RefPtr<Val>, sourceType)
+
+ RAW(
+ // Apply a set of substitutions to the bindings in this substitution
+ virtual RefPtr<Substitutions> SubstituteImpl(Substitutions* subst, int* ioDiff) override;
+
+ // Check if these are equivalent substitutiosn to another set
+ virtual bool Equals(Substitutions* subst) override;
+ virtual bool operator == (const Substitutions & subst) override
+ {
+ return Equals(const_cast<Substitutions*>(&subst));
+ }
+ virtual int GetHashCode() const override
+ {
+ return sourceType->GetHashCode();
+ }
+ )
+END_SYNTAX_CLASS()
+
ABSTRACT_SYNTAX_CLASS(SyntaxNode, SyntaxNodeBase)
END_SYNTAX_CLASS()
diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp
index 3e38955ba..a3a5fdcb6 100644
--- a/source/slang/syntax.cpp
+++ b/source/slang/syntax.cpp
@@ -91,6 +91,8 @@ ABSTRACT_SYNTAX_CLASS(Modifier, SyntaxNodeBase);
ABSTRACT_SYNTAX_CLASS(Expr, SyntaxNode);
ABSTRACT_SYNTAX_CLASS(Substitutions, SyntaxNode);
+ABSTRACT_SYNTAX_CLASS(GenericSubstitution, Substitutions);
+ABSTRACT_SYNTAX_CLASS(ThisTypeSubstitution, Substitutions);
#include "expr-defs.h"
#include "decl-defs.h"
@@ -98,8 +100,6 @@ ABSTRACT_SYNTAX_CLASS(Substitutions, SyntaxNode);
#include "stmt-defs.h"
#include "type-defs.h"
#include "val-defs.h"
-
-
#include "object-meta-end.h"
bool SyntaxClassBase::isSubClassOfImpl(SyntaxClassBase const& super) const
@@ -283,7 +283,7 @@ void Type::accept(IValVisitor* visitor, void* extra)
this, "PtrType").As<GenericDecl>();
auto typeDecl = genericDecl->inner;
- auto substitutions = new Substitutions();
+ auto substitutions = new GenericSubstitution();
substitutions->genericDecl = genericDecl.Ptr();
substitutions->args.Add(valueType);
@@ -414,38 +414,57 @@ void Type::accept(IValVisitor* visitor, void* extra)
// search for a substitution that might apply to us
for (auto s = subst; s; s = s->outer.Ptr())
{
- // the generic decl associated with the substitution list must be
- // the generic decl that declared this parameter
- auto genericDecl = s->genericDecl;
- if (genericDecl != genericTypeParamDecl->ParentDecl)
- continue;
-
- int index = 0;
- for (auto m : genericDecl->Members)
+ if (auto genericSubst = dynamic_cast<GenericSubstitution*>(s))
{
- if (m.Ptr() == genericTypeParamDecl)
- {
- // We've found it, so return the corresponding specialization argument
- (*ioDiff)++;
- return s->args[index];
- }
- else if(auto typeParam = m.As<GenericTypeParamDecl>())
+ // the generic decl associated with the substitution list must be
+ // the generic decl that declared this parameter
+ auto genericDecl = genericSubst->genericDecl;
+ if (genericDecl != genericTypeParamDecl->ParentDecl)
+ continue;
+
+ int index = 0;
+ for (auto m : genericDecl->Members)
{
- index++;
- }
- else if(auto valParam = m.As<GenericValueParamDecl>())
- {
- index++;
+ if (m.Ptr() == genericTypeParamDecl)
+ {
+ // We've found it, so return the corresponding specialization argument
+ (*ioDiff)++;
+ return genericSubst->args[index];
+ }
+ else if (auto typeParam = m.As<GenericTypeParamDecl>())
+ {
+ index++;
+ }
+ else if (auto valParam = m.As<GenericValueParamDecl>())
+ {
+ index++;
+ }
+ else
+ {
+ }
}
- else
+ }
+
+ }
+ }
+ // the second case we care about is when this decl type refers to an associatedtype decl
+ // we want to replace it with the actual associated type
+ else if (auto assocTypeDecl = dynamic_cast<AssocTypeDecl*>(declRef.getDecl()))
+ {
+ // search for a substitution that might apply to us
+ for (auto s = subst; s; s = s->outer.Ptr())
+ {
+ if (auto thisTypeSubst = dynamic_cast<ThisTypeSubstitution*>(s))
+ {
+ if (auto aggTypeDeclRef = thisTypeSubst->sourceType.As<DeclRefType>()->declRef.As<AggTypeDecl>())
{
+ Decl * targetType = nullptr;
+ if (aggTypeDeclRef.getDecl()->memberDictionary.TryGetValue(assocTypeDecl->getName(), targetType))
+ return DeclRefType::Create(this->getSession(), DeclRef<Decl>(targetType, aggTypeDeclRef.substitutions));
}
}
-
}
}
-
-
int diff = 0;
DeclRef<Decl> substDeclRef = declRef.SubstituteImpl(subst, &diff);
@@ -486,10 +505,25 @@ void Type::accept(IValVisitor* visitor, void* extra)
// we will construct a default specialization at the use
// site if needed.
- if( auto genericParent = declRef.GetParent().As<GenericDecl>() )
+ if (auto genericParent = declRef.GetParent().As<GenericDecl>())
{
auto subst = declRef.substitutions;
- if( !subst || subst->genericDecl != genericParent.decl )
+ // try find a substitution targeting this generic decl
+ bool substFound = false;
+ while (subst)
+ {
+ if (auto genSubst = dynamic_cast<GenericSubstitution*>(subst.Ptr()))
+ {
+ if (genSubst->genericDecl == genericParent.decl)
+ {
+ substFound = true;
+ break;
+ }
+ }
+ subst = subst->outer;
+ }
+ // we did not find an existing substituion, create a default one
+ if (!substFound)
{
declRef.substitutions = createDefaultSubstitutions(
session,
@@ -507,7 +541,7 @@ void Type::accept(IValVisitor* visitor, void* extra)
}
else if (auto magicMod = declRef.getDecl()->FindModifier<MagicTypeModifier>())
{
- Substitutions* subst = declRef.substitutions.Ptr();
+ GenericSubstitution* subst = declRef.substitutions.As<GenericSubstitution>().Ptr();
if (magicMod->name == "SamplerState")
{
@@ -761,7 +795,6 @@ void Type::accept(IValVisitor* visitor, void* extra)
bool NamedExpressionType::EqualsImpl(Type * /*type*/)
{
SLANG_UNEXPECTED("unreachable");
- return false;
}
Type* NamedExpressionType::CreateCanonicalType()
@@ -772,7 +805,6 @@ void Type::accept(IValVisitor* visitor, void* extra)
int NamedExpressionType::GetHashCode()
{
SLANG_UNEXPECTED("unreachable");
- return 0;
}
// FuncType
@@ -910,7 +942,6 @@ void Type::accept(IValVisitor* visitor, void* extra)
int TypeType::GetHashCode()
{
SLANG_UNEXPECTED("unreachable");
- return 0;
}
// GenericDeclRefType
@@ -940,125 +971,6 @@ void Type::accept(IValVisitor* visitor, void* extra)
return this;
}
- // AssocTypeDeclRefType
-
- String AssocTypeDeclRefType::ToString()
- {
- // TODO: what is appropriate here?
- return "<AssocType>";
- }
-
- bool AssocTypeDeclRefType::EqualsImpl(Type * type)
- {
- if (auto assocTypeDeclRefType = type->As<AssocTypeDeclRefType>())
- {
- return declRef.Equals(assocTypeDeclRefType->declRef);
- }
- return false;
- }
-
- RefPtr<Val> AssocTypeDeclRefType::SubstituteImpl(Substitutions* subst, int* ioDiff)
- {
- if (!sourceType)
- return this;
- auto substSourceType = sourceType->SubstituteImpl(subst, ioDiff);
- if (auto parentDeclRefType = substSourceType.As<DeclRefType>())
- {
- auto parentDeclRef = parentDeclRefType->declRef;
- DeclRef<AggTypeDecl> newParentDeclRef = parentDeclRef.As<AggTypeDecl>();
- // search for a substitution that might apply to us
- for (auto s = subst; s; s = s->outer.Ptr())
- {
- // the generic decl associated with the substitution list must be
- // the generic decl that declared this parameter
- auto genericDecl = s->genericDecl;
- if (genericDecl != parentDeclRef.getDecl()->ParentDecl)
- continue;
- int index = 0;
- for (auto m : genericDecl->Members)
- {
- if (m.Ptr() == parentDeclRef.getDecl())
- {
- // We've found it, so return the corresponding specialization argument
- (*ioDiff)++;
- if (auto declRef = s->args[index].As<DeclRefType>())
- {
- newParentDeclRef = (*declRef).declRef.As<AggTypeDecl>();
- goto searchEnd;
- }
- }
- else if (auto typeParam = m.As<GenericTypeParamDecl>())
- {
- index++;
- }
- else if (auto valParam = m.As<GenericValueParamDecl>())
- {
- index++;
- }
- else
- {
- }
- }
- }
- searchEnd:
- if (newParentDeclRef)
- {
- Decl* targetTypeDecl = nullptr;
- if (newParentDeclRef.getDecl()->memberDictionary.TryGetValue(this->GetDeclRef().decl->getName(), targetTypeDecl))
- {
- if (auto typeDefDecl = targetTypeDecl->As<TypeDefDecl>())
- return GetType(DeclRef<TypeDefDecl>(typeDefDecl, subst));
- else
- return DeclRefType::Create(this->getSession(), DeclRef<Decl>(targetTypeDecl, subst));
- }
- }
- }
-
- return this;
- }
-
- int AssocTypeDeclRefType::GetHashCode()
- {
- return declRef.GetHashCode();
- }
-
- Type* AssocTypeDeclRefType::CreateCanonicalType()
- {
- return this;
- }
-
- // GenericConstraintDeclRefType
-
- String GenericConstraintDeclRefType::ToString()
- {
- // TODO: what is appropriate here?
- return "<GenericConstraintType>";
- }
-
- bool GenericConstraintDeclRefType::EqualsImpl(Type * type)
- {
- if (auto other = type->As<GenericConstraintDeclRefType>())
- {
- return supType->Equals(other->supType) && subType->Equals(other->subType);
- }
- return false;
- }
-
- RefPtr<Val> GenericConstraintDeclRefType::SubstituteImpl(Substitutions* subst, int* ioDiff)
- {
- return subType->SubstituteImpl(subst, ioDiff);
- }
-
- int GenericConstraintDeclRefType::GetHashCode()
- {
- return combineHash(subType.GetHashCode(), supType.GetHashCode());
- }
-
- Type* GenericConstraintDeclRefType::CreateCanonicalType()
- {
- return this;
- }
-
// ArithmeticExpressionType
// VectorExpressionType
@@ -1091,24 +1003,24 @@ void Type::accept(IValVisitor* visitor, void* extra)
Type* MatrixExpressionType::getElementType()
{
- return this->declRef.substitutions->args[0].As<Type>().Ptr();
+ return this->declRef.substitutions.As<GenericSubstitution>()->args[0].As<Type>().Ptr();
}
IntVal* MatrixExpressionType::getRowCount()
{
- return this->declRef.substitutions->args[1].As<IntVal>().Ptr();
+ return this->declRef.substitutions.As<GenericSubstitution>()->args[1].As<IntVal>().Ptr();
}
IntVal* MatrixExpressionType::getColumnCount()
{
- return this->declRef.substitutions->args[2].As<IntVal>().Ptr();
+ return this->declRef.substitutions.As<GenericSubstitution>()->args[2].As<IntVal>().Ptr();
}
// PtrTypeBase
Type* PtrTypeBase::getValueType()
{
- return this->declRef.substitutions->args[0].As<Type>().Ptr();
+ return this->declRef.substitutions.As<GenericSubstitution>()->args[0].As<Type>().Ptr();
}
// GenericParamIntVal
@@ -1137,31 +1049,34 @@ void Type::accept(IValVisitor* visitor, void* extra)
// search for a substitution that might apply to us
for (auto s = subst; s; s = s->outer.Ptr())
{
- // the generic decl associated with the substitution list must be
- // the generic decl that declared this parameter
- auto genericDecl = s->genericDecl;
- if (genericDecl != declRef.getDecl()->ParentDecl)
- continue;
-
- int index = 0;
- for (auto m : genericDecl->Members)
+ if (auto genSubst = dynamic_cast<GenericSubstitution*>(s))
{
- if (m.Ptr() == declRef.getDecl())
- {
- // We've found it, so return the corresponding specialization argument
- (*ioDiff)++;
- return s->args[index];
- }
- else if(auto typeParam = m.As<GenericTypeParamDecl>())
- {
- index++;
- }
- else if(auto valParam = m.As<GenericValueParamDecl>())
- {
- index++;
- }
- else
+ // the generic decl associated with the substitution list must be
+ // the generic decl that declared this parameter
+ auto genericDecl = genSubst->genericDecl;
+ if (genericDecl != declRef.getDecl()->ParentDecl)
+ continue;
+
+ int index = 0;
+ for (auto m : genericDecl->Members)
{
+ if (m.Ptr() == declRef.getDecl())
+ {
+ // We've found it, so return the corresponding specialization argument
+ (*ioDiff)++;
+ return genSubst->args[index];
+ }
+ else if (auto typeParam = m.As<GenericTypeParamDecl>())
+ {
+ index++;
+ }
+ else if (auto valParam = m.As<GenericValueParamDecl>())
+ {
+ index++;
+ }
+ else
+ {
+ }
}
}
}
@@ -1172,12 +1087,12 @@ void Type::accept(IValVisitor* visitor, void* extra)
// Substitutions
- RefPtr<Substitutions> Substitutions::SubstituteImpl(Substitutions* subst, int* ioDiff)
+ RefPtr<Substitutions> GenericSubstitution::SubstituteImpl(Substitutions* subst, int* ioDiff)
{
if (!this) return nullptr;
int diff = 0;
- auto outerSubst = outer->SubstituteImpl(subst, &diff);
+ auto outerSubst = outer ? outer->SubstituteImpl(subst, &diff) : nullptr;
List<RefPtr<Val>> substArgs;
for (auto a : args)
@@ -1188,35 +1103,73 @@ void Type::accept(IValVisitor* visitor, void* extra)
if (!diff) return this;
(*ioDiff)++;
- auto substSubst = new Substitutions();
+ auto substSubst = new GenericSubstitution();
substSubst->genericDecl = genericDecl;
substSubst->args = substArgs;
return substSubst;
}
- bool Substitutions::Equals(Substitutions* subst)
+ bool GenericSubstitution::Equals(Substitutions* subst)
{
// both must be NULL, or non-NULL
if (!this || !subst)
return !this && !subst;
-
- if (genericDecl != subst->genericDecl)
+ auto genericSubst = dynamic_cast<GenericSubstitution*>(subst);
+ if (!genericSubst)
+ return false;
+ if (genericDecl != genericSubst->genericDecl)
return false;
UInt argCount = args.Count();
- SLANG_RELEASE_ASSERT(args.Count() == subst->args.Count());
+ SLANG_RELEASE_ASSERT(args.Count() == genericSubst->args.Count());
for (UInt aa = 0; aa < argCount; ++aa)
{
- if (!args[aa]->EqualsVal(subst->args[aa].Ptr()))
+ if (!args[aa]->EqualsVal(genericSubst->args[aa].Ptr()))
return false;
}
+ if (!outer)
+ return !subst->outer;
+
if (!outer->Equals(subst->outer.Ptr()))
return false;
return true;
}
+ RefPtr<Substitutions> ThisTypeSubstitution::SubstituteImpl(Substitutions* subst, int* ioDiff)
+ {
+ if (!this) return nullptr;
+
+ int diff = 0;
+ auto outerSubst = outer->SubstituteImpl(subst, &diff);
+
+ auto newSourceType = sourceType->SubstituteImpl(subst, ioDiff);
+ if (!diff) return this;
+
+ (*ioDiff)++;
+ auto substSubst = new ThisTypeSubstitution();
+ substSubst->sourceType = newSourceType;
+ substSubst->outer = outerSubst;
+ return substSubst;
+ }
+
+ bool ThisTypeSubstitution::Equals(Substitutions* subst)
+ {
+ // both must be NULL, or non-NULL
+ if (!this || !subst)
+ return !this && !subst;
+ auto thisSubst = dynamic_cast<ThisTypeSubstitution*>(subst);
+ if (!thisSubst)
+ return false;
+ if (!thisSubst->sourceType->EqualsVal(sourceType))
+ return false;
+ if (!outer->Equals(subst->outer.Ptr()))
+ return false;
+ return true;
+ }
+
+
// DeclRefBase
@@ -1248,8 +1201,6 @@ void Type::accept(IValVisitor* visitor, void* extra)
return expr;
SLANG_UNIMPLEMENTED_X("generic substitution into expressions");
-
- return expr;
}
@@ -1277,7 +1228,8 @@ void Type::accept(IValVisitor* visitor, void* extra)
{
if (decl != declRef.decl)
return false;
-
+ if (!substitutions)
+ return !declRef.substitutions;
if (!substitutions->Equals(declRef.substitutions.Ptr()))
return false;
@@ -1298,7 +1250,8 @@ void Type::accept(IValVisitor* visitor, void* extra)
if (auto parentGeneric = dynamic_cast<GenericDecl*>(parentDecl))
{
- if (substitutions && substitutions->genericDecl == parentDecl)
+ auto genSubst = substitutions.As<GenericSubstitution>();
+ if (genSubst && genSubst->genericDecl == parentDecl)
{
// We strip away the specializations that were applied to
// the parent, since we were asked for a reference *to* the parent.
@@ -1427,7 +1380,6 @@ void Type::accept(IValVisitor* visitor, void* extra)
else
{
SLANG_UNEXPECTED("unhandled syntax class name");
- return nullptr;
}
}
@@ -1437,12 +1389,12 @@ void Type::accept(IValVisitor* visitor, void* extra)
Type* HLSLPatchType::getElementType()
{
- return this->declRef.substitutions->args[0].As<Type>().Ptr();
+ return this->declRef.substitutions.As<GenericSubstitution>()->args[0].As<Type>().Ptr();
}
IntVal* HLSLPatchType::getElementCount()
{
- return this->declRef.substitutions->args[1].As<IntVal>().Ptr();
+ return this->declRef.substitutions.As<GenericSubstitution>()->args[1].As<IntVal>().Ptr();
}
// Constructors for types
diff --git a/source/slang/type-defs.h b/source/slang/type-defs.h
index 2052e1eb8..60f519c82 100644
--- a/source/slang/type-defs.h
+++ b/source/slang/type-defs.h
@@ -489,59 +489,4 @@ protected:
virtual int GetHashCode() override;
virtual Type* CreateCanonicalType() override;
)
-END_SYNTAX_CLASS()
-
-// The "type" of an expression that references a asscoiated type decl (via 'assoctype' keyword).
-SYNTAX_CLASS(AssocTypeDeclRefType, Type)
- DECL_FIELD(DeclRef<AssocTypeDecl>, declRef)
- DECL_FIELD(RefPtr<Type>, sourceType)
- RAW(
- AssocTypeDeclRefType()
- {}
- AssocTypeDeclRefType(
- DeclRef<AssocTypeDecl> declRef)
- : declRef(declRef)
- {}
-
-
- DeclRef<AssocTypeDecl> const& GetDeclRef() const { return declRef; }
-
- virtual String ToString() override;
-
- protected:
- virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override;
- virtual bool EqualsImpl(Type * type) override;
- virtual int GetHashCode() override;
- virtual Type* CreateCanonicalType() override;
- )
-END_SYNTAX_CLASS()
-
-// The "type" of an generic constraint, which wraps both the sub and sup (interface) type
-// the sub type can be used in associated type substitution in later type evaluation
-SYNTAX_CLASS(GenericConstraintDeclRefType, Type)
- DECL_FIELD(RefPtr<Type>, subType)
- DECL_FIELD(RefPtr<Type>, supType)
-RAW(
- GenericConstraintDeclRefType()
- {}
- GenericConstraintDeclRefType(Session* session,
- RefPtr<Type> sub,
- RefPtr<Type> sup)
- : subType(sub), supType(sup)
- {
- setSession(session);
- }
-
-
- RefPtr<Type> const& GetSupType() const { return supType; }
- RefPtr<Type> const& GetSubType() const { return subType; }
-
- virtual String ToString() override;
-
- protected:
- virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override;
- virtual bool EqualsImpl(Type * type) override;
- virtual int GetHashCode() override;
- virtual Type* CreateCanonicalType() override;
-)
END_SYNTAX_CLASS() \ No newline at end of file
diff --git a/source/slang/vm.cpp b/source/slang/vm.cpp
index f129d15e0..d5bacaa36 100644
--- a/source/slang/vm.cpp
+++ b/source/slang/vm.cpp
@@ -513,8 +513,8 @@ void computeTypeSizeAlign(
break;
default:
- SLANG_UNIMPLEMENTED_X("type sizing");
impl->size = 0;
+ SLANG_UNIMPLEMENTED_X("type sizing");
break;
}
@@ -528,7 +528,7 @@ void computeTypeSizeAlign(
}
VMType getType(
- VM* vm,
+ VM* /*vm*/,
VMTypeImpl* typeImpl)
{
// TODO: need to look up an existing type that matches...
@@ -587,7 +587,7 @@ VMType loadVMType(
VMTypeImpl* impl = (VMTypeImpl*) alloca(size);
memset(impl, 0, size);
impl->op = bcType->op;
- impl->argCount = argCount;
+ impl->argCount = (uint32_t)argCount;
VMVal* args = (VMVal*) (impl + 1);
for(UInt aa = 0; aa < argCount; ++aa)
@@ -597,14 +597,11 @@ VMType loadVMType(
return getType(vmModule->vm, impl);
}
-
- SLANG_UNEXPECTED("unimplemented");
- return VMType();
break;
}
}
-void* allocateImpl(VM* vm, UInt size, UInt align)
+void* allocateImpl(VM* /*vm*/, UInt size, UInt /*align*/)
{
void* ptr = malloc(size);
memset(ptr, 0, size);
@@ -666,7 +663,7 @@ void* loadVMSymbol(
VMModule* loadVMModuleInstance(
VM* vm,
void const* bytecode,
- size_t bytecodeSize)
+ size_t /*bytecodeSize*/)
{
BCHeader* bcHeader = (BCHeader*) bytecode;
@@ -732,14 +729,14 @@ void* findGlobalSymbolPtr(
continue;
if(strcmp(symbolName, name) == 0)
- return getGlobalPtr(module, ss);
+ return getGlobalPtr(module, (uint32_t)ss);
}
return nullptr;
}
VMThread* createThread(
- VM* vm)
+ VM* /*vm*/)
{
VMThread* thread = new VMThread();
thread->frame = nullptr;
@@ -863,7 +860,7 @@ void resumeThread(
case kIROp_BufferStore:
{
VMType resultType = decodeType(frame, &ip);
- UInt argCount = decodeUInt(&ip);
+ /*UInt argCount =*/ decodeUInt(&ip);
char* bufferData = decodeOperand<char*>(frame, &ip);
uint32_t index = decodeOperand<uint32_t>(frame, &ip);
@@ -944,10 +941,9 @@ void resumeThread(
case kIROp_ReturnVal:
{
VMType instType = decodeType(frame, &ip);
- UInt argCount = decodeUInt(&ip);
+ /*UInt argCount =*/ decodeUInt(&ip);
void* argPtr = decodeOperandPtr<void>(frame, &ip);
- VMFrame* oldFrame = frame;
VMFrame* newFrame = frame->parent;
vmThread->frame = newFrame;
@@ -980,7 +976,7 @@ void resumeThread(
Int destinationBlock = decodeSInt(&ip);
for( UInt aa = 2; aa < argCount; ++aa )
{
- void* argPtr = decodeOperandPtr<void>(frame, &ip);
+ decodeOperandPtr<void>(frame, &ip);
}
// TODO: we need to deal with the case of
@@ -1006,7 +1002,7 @@ void resumeThread(
Int falseBlockID = decodeSInt(&ip);
for( UInt aa = 4; aa < argCount; ++aa )
{
- void* argPtr = decodeOperandPtr<void>(frame, &ip);
+ decodeOperandPtr<void>(frame, &ip);
}
Int destinationBlock = *condition ? trueBlockID : falseBlockID;
@@ -1025,7 +1021,7 @@ void resumeThread(
// knowing too much about an instruction...
VMType resultType = decodeType(frame, &ip);
- UInt argCount = decodeUInt(&ip);
+ /*UInt argCount =*/ decodeUInt(&ip);
void* argPtrs[16] = { 0 };
auto leftOpnd = decodeOperandPtrAndType(frame, &ip);
auto type = leftOpnd.type;
@@ -1050,7 +1046,7 @@ void resumeThread(
case kIROp_Mul:
{
VMType type = decodeType(frame, &ip);
- UInt argCount = decodeUInt(&ip);
+ /*UInt argCount =*/ decodeUInt(&ip);
void* leftPtr = decodeOperandPtr<void>(frame, &ip);
void* rightPtr = decodeOperandPtr<void>(frame, &ip);
@@ -1072,7 +1068,7 @@ void resumeThread(
case kIROp_Sub:
{
VMType type = decodeType(frame, &ip);
- UInt argCount = decodeUInt(&ip);
+ /*UInt argCount =*/ decodeUInt(&ip);
void* leftPtr = decodeOperandPtr<void>(frame, &ip);
void* rightPtr = decodeOperandPtr<void>(frame, &ip);
diff --git a/tests/compute/assoctype-complex.slang b/tests/compute/assoctype-complex.slang
index de3f1a103..f29d231b6 100644
--- a/tests/compute/assoctype-complex.slang
+++ b/tests/compute/assoctype-complex.slang
@@ -30,30 +30,18 @@ struct Simple : ISimple
return v0.sub(4, v1.sub(1,2));
}
};
-/*
+
__generic<T:ISimple>
T.U.V test(T simple, T.U v0, T.U v1)
{
return simple.add(v0, v1);
}
-__generic<T:__BuiltinArithmeticType>
-T test(T v0, T v1)
-{
- return v0 + v1;
-}
-*/
-__generic<T:__BuiltinFloatingPointType>
-T test(T v0, T v1)
-{
- return T(3.0);
-}
[numthreads(4, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
{
- //Simple s;
- //Val v0, v1;
- //float outVal = test(s, v0, v1); // == 1.0
- float outVal = test<float>(1.0, 2.0);
+ Simple s;
+ Val v0, v1;
+ float outVal = test(s, v0, v1); // == 1.0
outputBuffer[dispatchThreadID.x] = outVal;
} \ No newline at end of file
diff --git a/tests/compute/generics-constraint1.slang b/tests/compute/generics-constraint1.slang
index ff90c1cc9..aa8d398e8 100644
--- a/tests/compute/generics-constraint1.slang
+++ b/tests/compute/generics-constraint1.slang
@@ -6,7 +6,7 @@ RWStructuredBuffer<float> outputBuffer;
__generic<T:__BuiltinFloatingPointType>
T test(T v0, T v1)
{
- return T(3.0);
+ return v0;
}
[numthreads(4, 1, 1)]
diff --git a/tests/compute/generics-constructor.slang b/tests/compute/generics-constructor.slang
new file mode 100644
index 000000000..c7473cc8b
--- /dev/null
+++ b/tests/compute/generics-constructor.slang
@@ -0,0 +1,17 @@
+//TEST(smoke, compute):COMPARE_COMPUTE:-xslang -use-ir
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
+
+RWStructuredBuffer<float> outputBuffer;
+
+__generic<T:__BuiltinFloatingPointType>
+T test(T v0, T v1)
+{
+ return T(3.0);
+}
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ float outVal = test<float>(1.0, 2.0);
+ outputBuffer[dispatchThreadID.x] = outVal;
+} \ No newline at end of file