summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--examples/hello/hello.cpp12
-rw-r--r--slang.h39
-rw-r--r--source/slang/bytecode.cpp53
-rw-r--r--source/slang/bytecode.h22
-rw-r--r--source/slang/check.cpp128
-rw-r--r--source/slang/compiler.cpp207
-rw-r--r--source/slang/compiler.h69
-rw-r--r--source/slang/diagnostic-defs.h8
-rw-r--r--source/slang/emit.cpp167
-rw-r--r--source/slang/emit.h8
-rw-r--r--source/slang/ir-insts.h27
-rw-r--r--source/slang/ir.cpp487
-rw-r--r--source/slang/ir.h20
-rw-r--r--source/slang/lower-to-ir.cpp177
-rw-r--r--source/slang/lower-to-ir.h7
-rw-r--r--source/slang/lower.cpp45
-rw-r--r--source/slang/mangle.h2
-rw-r--r--source/slang/options.cpp15
-rw-r--r--source/slang/parameter-binding.cpp69
-rw-r--r--source/slang/parameter-binding.h4
-rw-r--r--source/slang/profile-defs.h2
-rw-r--r--source/slang/reflection.cpp36
-rw-r--r--source/slang/slang.cpp192
-rw-r--r--source/slang/syntax-visitors.h1
-rw-r--r--source/slang/type-layout.cpp1
-rw-r--r--source/slang/vm.cpp6
-rw-r--r--source/slangc/main.cpp19
-rw-r--r--tests/ir/loop.slang2
-rw-r--r--tools/eval-test/main.cpp10
-rw-r--r--tools/render-test/render-d3d11.cpp255
-rw-r--r--tools/render-test/slang-support.cpp16
31 files changed, 1270 insertions, 836 deletions
diff --git a/examples/hello/hello.cpp b/examples/hello/hello.cpp
index 8e48b3c13..2d7bb81b5 100644
--- a/examples/hello/hello.cpp
+++ b/examples/hello/hello.cpp
@@ -100,8 +100,8 @@ HRESULT initialize( ID3D11Device* dxDevice )
char const* vertexProfileName = "vs_4_0";
char const* fragmentProfileName = "ps_4_0";
- spAddEntryPoint(slangRequest, translationUnitIndex, vertexEntryPointName, spFindProfile(slangSession, vertexProfileName));
- spAddEntryPoint(slangRequest, translationUnitIndex, fragmentEntryPointName, spFindProfile(slangSession, fragmentProfileName));
+ int vertexIndex = spAddEntryPoint(slangRequest, translationUnitIndex, vertexEntryPointName, spFindProfile(slangSession, vertexProfileName));
+ int fragmentIndex = spAddEntryPoint(slangRequest, translationUnitIndex, fragmentEntryPointName, spFindProfile(slangSession, fragmentProfileName));
int compileErr = spCompile(slangRequest);
if(auto diagnostics = spGetDiagnosticOutput(slangRequest))
@@ -114,17 +114,17 @@ HRESULT initialize( ID3D11Device* dxDevice )
return E_FAIL;
}
- char const* translatedCode = spGetTranslationUnitSource(slangRequest, translationUnitIndex);
-
+ char const* vertexCode = spGetEntryPointSource(slangRequest, vertexIndex);
+ char const* fragmentCode = spGetEntryPointSource(slangRequest, fragmentIndex);
// TODO(tfoley): Query the required constant-buffer size
int constantBufferSize = 16 * sizeof(float);
// Compile the generated HLSL code
- ID3DBlob* dxVertexShaderBlob = compileHLSLShader(translatedCode, vertexEntryPointName, vertexProfileName);
+ ID3DBlob* dxVertexShaderBlob = compileHLSLShader(vertexCode, vertexEntryPointName, vertexProfileName);
if(!dxVertexShaderBlob) return E_FAIL;
- ID3DBlob* dxPixelShaderBlob = compileHLSLShader(translatedCode, fragmentEntryPointName, fragmentProfileName);
+ ID3DBlob* dxPixelShaderBlob = compileHLSLShader(fragmentCode, fragmentEntryPointName, fragmentProfileName);
if(!dxPixelShaderBlob) return E_FAIL;
HRESULT hr = S_OK;
diff --git a/slang.h b/slang.h
index 231ef7472..ff79d698b 100644
--- a/slang.h
+++ b/slang.h
@@ -89,7 +89,20 @@ extern "C"
SLANG_SPIRV_ASM,
SLANG_DXBC,
SLANG_DXBC_ASM,
- SLANG_IR,
+ };
+
+ /* A "container format" describes the way that the outputs
+ for multiple files, entry points, targets, etc. should be
+ combined into a single artifact for output. */
+ typedef int SlangContainerFormat;
+ enum
+ {
+ /* Don't generate a container. */
+ SLANG_CONTAINER_FORMAT_NONE,
+
+ /* Generate a container in the `.slang-module` format,
+ which includes reflection information, compiled kernels, etc. */
+ SLANG_CONTAINER_FORMAT_SLANG_MODULE,
};
typedef int SlangPassThrough;
@@ -231,6 +244,20 @@ extern "C"
SlangCompileRequest* request,
int target);
+ /*!
+ @brief Add a code-generation target to be used.
+ */
+ SLANG_API void spAddCodeGenTarget(
+ SlangCompileRequest* request,
+ SlangCompileTarget target);
+
+ /*!
+ @brief Set the container format to be used for binary output.
+ */
+ SLANG_API void spSetOutputContainerFormat(
+ SlangCompileRequest* request,
+ SlangContainerFormat format);
+
SLANG_API void spSetPassThrough(
SlangCompileRequest* request,
SlangPassThrough passThrough);
@@ -389,6 +416,16 @@ extern "C"
int entryPointIndex,
size_t* outSize);
+ /** Get the output bytecode associated with an entire compile request.
+
+ The lifetime of the output pointer is the same as `request`.
+ */
+ SLANG_API void const* spGetCompileRequestCode(
+ SlangCompileRequest* request,
+ size_t* outSize);
+
+
+
typedef struct SlangVM SlangVM;
typedef struct SlangVMModule SlangVMModule;
typedef struct SlangVMFunc SlangVMFunc;
diff --git a/source/slang/bytecode.cpp b/source/slang/bytecode.cpp
index 8d8d609b9..085b7303d 100644
--- a/source/slang/bytecode.cpp
+++ b/source/slang/bytecode.cpp
@@ -936,6 +936,12 @@ BytecodeGenerationPtr<BCModule> generateBytecodeForModule(
BytecodeGenerationContext* context,
IRModule* irModule)
{
+ if (!irModule)
+ {
+ // Not IR module? Then return a null pointer.
+ return BytecodeGenerationPtr<BCModule>();
+ }
+
// A module in the bytecode is mostly just a list of the
// global symbols in the module.
//
@@ -1032,9 +1038,9 @@ BytecodeGenerationPtr<BCModule> generateBytecodeForModule(
return bcModule;
}
-void generateBytecodeStream(
+void generateBytecodeContainer(
BytecodeGenerationContext* context,
- IRModule* irModule)
+ CompileRequest* compileReq)
{
// Header must be the very first thing in the bytecode stream
BytecodeGenerationPtr<BCHeader> header = allocate<BCHeader>(context);
@@ -1042,9 +1048,49 @@ void generateBytecodeStream(
memcpy(header->magic, "slang\0bc", sizeof(header->magic));
header->version = 0;
- header->module = generateBytecodeForModule(context, irModule);
+ // TODO: Need to generate BC representation of all the public/exported
+ // declrations in the translation units, so that they can be used to
+ // resolve depenencies downstream.
+
+ // TODO: Need to dump BC representation of compiled kernel codes
+ // for each specified code-generation target.
+
+ UInt translationUnitCount = compileReq->translationUnits.Count();
+
+ List<BytecodeGenerationPtr<BCModule>> bcModulesList;
+ for (auto translationUnitReq : compileReq->translationUnits)
+ {
+ auto bcModule = generateBytecodeForModule(context, translationUnitReq->irModule);
+ bcModulesList.Add(bcModule);
+ }
+
+ UInt bcModuleCount = bcModulesList.Count();
+ header->moduleCount = bcModuleCount;
+
+ auto bcModules = allocateArray<BCPtr<BCModule>>(context, bcModuleCount);
+ header->modules = bcModules;
+ for(UInt ii = 0; ii < bcModuleCount; ++ii)
+ {
+ bcModules[ii] = bcModulesList[ii];
+ }
+}
+
+void generateBytecodeForCompileRequest(
+ CompileRequest* compileReq)
+{
+ SharedBytecodeGenerationContext sharedContext;
+
+ BytecodeGenerationContext context;
+ context.shared = &sharedContext;
+
+ generateBytecodeContainer(&context, compileReq);
+
+ compileReq->generatedBytecode = sharedContext.bytecode;
}
+// TODO: Need to support IR emit at the whole-module/compile-request
+// level, and not just for individual entry points.
+#if 0
List<uint8_t> emitSlangIRForEntryPoint(
EntryPointRequest* entryPoint)
{
@@ -1072,5 +1118,6 @@ List<uint8_t> emitSlangIRForEntryPoint(
return sharedContext.bytecode;
}
+#endif
} // namespace Slang
diff --git a/source/slang/bytecode.h b/source/slang/bytecode.h
index 1ea16406f..f38007ba9 100644
--- a/source/slang/bytecode.h
+++ b/source/slang/bytecode.h
@@ -225,11 +225,27 @@ struct BCHeader
// kinds of data without having to revise
// the schema here.
- // The bytecode representation of the module
- BCPtr<BCModule> module;
+ // TODO: should include AST declaration structure
+ // here, which can be used for refleciton, and
+ // also loaded to resolve dependencies when
+ // compiling other modules.
+
+ // TODO: Include the original entry point requests?
+
+ // Zero or more IR modules, corresponding to
+ // the translation units of the original compile
+ // request.
+ uint32_t moduleCount;
+ BCPtr<BCPtr<BCModule>> modules;
+
+ // TODO: should enumerate targets here, and
+ // include reflection layout info + compiled
+ // entry points for each target.
};
-
+struct CompileRequest;
+void generateBytecodeForCompileRequest(
+ CompileRequest* compileReq);
}
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index e8fd16b94..f12e7e55d 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -5622,6 +5622,115 @@ namespace Slang
}
};
+ bool isPrimaryDecl(
+ CallableDecl* decl)
+ {
+ assert(decl);
+ return (!decl->primaryDecl) || (decl == decl->primaryDecl);
+ }
+
+ void validateEntryPoint(
+ EntryPointRequest* entryPoint)
+ {
+ // The first step in validating the entry point is to find
+ // the (unique) function declaration that matches its name.
+
+ auto translationUnit = entryPoint->getTranslationUnit();
+ auto sink = &entryPoint->compileRequest->mSink;
+ auto translationUnitSyntax = translationUnit->SyntaxNode;
+
+
+ // Make sure we've got a query-able member dictionary
+ buildMemberDictionary(translationUnitSyntax);
+
+ // We will look up any global-scope declarations in the translation
+ // unit that match the name of our entry point.
+ Decl* firstDeclWithName = nullptr;
+ if( !translationUnitSyntax->memberDictionary.TryGetValue(entryPoint->name, firstDeclWithName) )
+ {
+ // If there doesn't appear to be any such declaration, then
+ // we need to diagnose it as an error, and then bail out.
+ sink->diagnose(translationUnitSyntax, Diagnostics::entryPointFunctionNotFound, entryPoint->name);
+ return;
+ }
+
+ // We found at least one global-scope declaration with the right name,
+ // but (1) it might not be a function, and (2) there might be
+ // more than one function.
+ //
+ // We'll walk the linked list of declarations with the same name,
+ // to see what we find. Along the way we'll keep track of the
+ // first function declaration we find, if any:
+ FuncDecl* entryPointFuncDecl = nullptr;
+ for(auto ee = firstDeclWithName; ee; ee = ee->nextInContainerWithSameName)
+ {
+ // Is this declaration a function?
+ if (auto funcDecl = dynamic_cast<FuncDecl*>(ee))
+ {
+ // Skip non-primary declarations, so that
+ // we don't give an error when an entry
+ // point is forward-declared.
+ if (!isPrimaryDecl(funcDecl))
+ continue;
+
+ // is this the first one we've seen?
+ if (!entryPointFuncDecl)
+ {
+ // If so, this is a candidate to be
+ // the entry point function.
+ entryPointFuncDecl = funcDecl;
+ }
+ else
+ {
+ // Uh-oh! We've already seen a function declaration with this
+ // name before, so the whole thing is ambiguous. We need
+ // to diagnose and bail out.
+
+ sink->diagnose(translationUnitSyntax, Diagnostics::ambiguousEntryPoint, entryPoint->name);
+
+ // List all of the declarations that the user *might* mean
+ for (auto ff = firstDeclWithName; ff; ff = ff->nextInContainerWithSameName)
+ {
+ if (auto candidate = dynamic_cast<FuncDecl*>(ff))
+ {
+ sink->diagnose(candidate, Diagnostics::entryPointCandidate, candidate->getName());
+ }
+ }
+
+ // Bail out.
+ return;
+ }
+ }
+ }
+
+ // Did we find a function declaration in our search?
+ if(!entryPointFuncDecl)
+ {
+ // If not, then we need to diagnose the error.
+ // For convenience, we will point to the first
+ // declaration with the right name, that wasn't a function.
+ sink->diagnose(firstDeclWithName, Diagnostics::entryPointSymbolNotAFunction, entryPoint->name);
+ return;
+ }
+
+ // TODO: it is possible that the entry point was declared with
+ // profile or target overloading. Is there anything that we need
+ // to do at this point to filter out declarations that aren't
+ // relevant to the selected profile for the entry point?
+
+ // Phew, we have at least found a suitable decl.
+ // Let's record that in the entry-point request so
+ // that we don't have to re-do this effort again later.
+ entryPoint->decl = entryPointFuncDecl;
+
+ // TODO: after all that work, we are now in a position to start
+ // validating the declaration itself. E.g., we should check if
+ // the declared input/output parameters have suitable semantics,
+ // if they are of types that are appropriate to the stage, etc.
+ }
+
+
+
void checkTranslationUnit(
TranslationUnitRequest* translationUnit)
{
@@ -5630,9 +5739,28 @@ namespace Slang
translationUnit->compileRequest,
translationUnit);
+ // Apply the visitor to do the main semantic
+ // checking that is required on all declarations
+ // in the translation unit.
visitor.checkDecl(translationUnit->SyntaxNode);
+
+ // Next, do follow-up validation on any entry
+ // points that the user declared via API or
+ // command line, to ensure that they meet
+ // requirements.
+ //
+ // Note: We may eventually have syntax to
+ // identify entry points via a modifier on
+ // declarations, and in this case they should
+ // probably get validated as part of orindary
+ // checking above.
+ for (auto entryPoint : translationUnit->entryPoints)
+ {
+ validateEntryPoint(entryPoint);
+ }
}
+
//
// Get the type to use when referencing a declaration
diff --git a/source/slang/compiler.cpp b/source/slang/compiler.cpp
index 10755370a..8cf801a79 100644
--- a/source/slang/compiler.cpp
+++ b/source/slang/compiler.cpp
@@ -3,8 +3,10 @@
#include "../core/basic.h"
#include "../core/platform.h"
#include "../core/slang-io.h"
+#include "bytecode.h"
#include "compiler.h"
#include "lexer.h"
+#include "lower-to-ir.h"
#include "parameter-binding.h"
#include "parser.h"
#include "preprocessor.h"
@@ -110,7 +112,8 @@ namespace Slang
//
String emitHLSLForEntryPoint(
- EntryPointRequest* entryPoint)
+ EntryPointRequest* entryPoint,
+ TargetRequest* targetReq)
{
auto compileRequest = entryPoint->compileRequest;
auto translationUnit = entryPoint->getTranslationUnit();
@@ -149,13 +152,15 @@ namespace Slang
{
return emitEntryPoint(
entryPoint,
- compileRequest->layout.Ptr(),
- CodeGenTarget::HLSL);
+ targetReq->layout.Ptr(),
+ CodeGenTarget::HLSL,
+ targetReq->target);
}
}
String emitGLSLForEntryPoint(
- EntryPointRequest* entryPoint)
+ EntryPointRequest* entryPoint,
+ TargetRequest* targetReq)
{
auto compileRequest = entryPoint->compileRequest;
auto translationUnit = entryPoint->getTranslationUnit();
@@ -194,8 +199,9 @@ namespace Slang
// so that we properly emit it as the `main` function.
return emitEntryPoint(
entryPoint,
- compileRequest->layout.Ptr(),
- CodeGenTarget::GLSL);
+ targetReq->layout.Ptr(),
+ CodeGenTarget::GLSL,
+ targetReq->target);
}
}
@@ -232,7 +238,8 @@ namespace Slang
}
List<uint8_t> EmitDXBytecodeForEntryPoint(
- EntryPointRequest* entryPoint)
+ EntryPointRequest* entryPoint,
+ TargetRequest* targetReq)
{
static pD3DCompile D3DCompile_ = nullptr;
if (!D3DCompile_)
@@ -246,7 +253,7 @@ namespace Slang
return List<uint8_t>();
}
- auto hlslCode = emitHLSLForEntryPoint(entryPoint);
+ auto hlslCode = emitHLSLForEntryPoint(entryPoint, targetReq);
maybeDumpIntermediate(entryPoint->compileRequest, hlslCode.Buffer(), CodeGenTarget::HLSL);
ID3DBlob* codeBlob;
@@ -333,10 +340,11 @@ namespace Slang
}
String EmitDXBytecodeAssemblyForEntryPoint(
- EntryPointRequest* entryPoint)
+ EntryPointRequest* entryPoint,
+ TargetRequest* targetReq)
{
- List<uint8_t> dxbc = EmitDXBytecodeForEntryPoint(entryPoint);
+ List<uint8_t> dxbc = EmitDXBytecodeForEntryPoint(entryPoint, targetReq);
if (!dxbc.Count())
{
return String();
@@ -440,9 +448,10 @@ namespace Slang
}
List<uint8_t> emitSPIRVForEntryPoint(
- EntryPointRequest* entryPoint)
+ EntryPointRequest* entryPoint,
+ TargetRequest* targetReq)
{
- String rawGLSL = emitGLSLForEntryPoint(entryPoint);
+ String rawGLSL = emitGLSLForEntryPoint(entryPoint, targetReq);
maybeDumpIntermediate(entryPoint->compileRequest, rawGLSL.Buffer(), CodeGenTarget::GLSL);
List<uint8_t> output;
@@ -473,9 +482,10 @@ namespace Slang
}
String emitSPIRVAssemblyForEntryPoint(
- EntryPointRequest* entryPoint)
+ EntryPointRequest* entryPoint,
+ TargetRequest* targetReq)
{
- List<uint8_t> spirv = emitSPIRVForEntryPoint(entryPoint);
+ List<uint8_t> spirv = emitSPIRVForEntryPoint(entryPoint, targetReq);
if (spirv.Count() == 0)
return String();
@@ -484,26 +494,21 @@ namespace Slang
}
#endif
- List<uint8_t> emitSlangIRForEntryPoint(
- EntryPointRequest* entryPoint);
-
- String emitSlangIRAssemblyForEntryPoint(
- EntryPointRequest* entryPoint);
-
// Do emit logic for a single entry point
CompileResult emitEntryPoint(
- EntryPointRequest* entryPoint)
+ EntryPointRequest* entryPoint,
+ TargetRequest* targetReq)
{
CompileResult result;
auto compileRequest = entryPoint->compileRequest;
- auto target = compileRequest->Target;
+ auto target = targetReq->target;
switch (target)
{
case CodeGenTarget::HLSL:
{
- String code = emitHLSLForEntryPoint(entryPoint);
+ String code = emitHLSLForEntryPoint(entryPoint, targetReq);
maybeDumpIntermediate(compileRequest, code.Buffer(), target);
result = CompileResult(code);
}
@@ -511,7 +516,7 @@ namespace Slang
case CodeGenTarget::GLSL:
{
- String code = emitGLSLForEntryPoint(entryPoint);
+ String code = emitGLSLForEntryPoint(entryPoint, targetReq);
maybeDumpIntermediate(compileRequest, code.Buffer(), target);
result = CompileResult(code);
}
@@ -520,7 +525,7 @@ namespace Slang
#if SLANG_ENABLE_DXBC_SUPPORT
case CodeGenTarget::DXBytecode:
{
- List<uint8_t> code = EmitDXBytecodeForEntryPoint(entryPoint);
+ List<uint8_t> code = EmitDXBytecodeForEntryPoint(entryPoint, targetReq);
maybeDumpIntermediate(compileRequest, code.Buffer(), code.Count(), target);
result = CompileResult(code);
}
@@ -528,7 +533,7 @@ namespace Slang
case CodeGenTarget::DXBytecodeAssembly:
{
- String code = EmitDXBytecodeAssemblyForEntryPoint(entryPoint);
+ String code = EmitDXBytecodeAssemblyForEntryPoint(entryPoint, targetReq);
maybeDumpIntermediate(compileRequest, code.Buffer(), target);
result = CompileResult(code);
}
@@ -537,7 +542,7 @@ namespace Slang
case CodeGenTarget::SPIRV:
{
- List<uint8_t> code = emitSPIRVForEntryPoint(entryPoint);
+ List<uint8_t> code = emitSPIRVForEntryPoint(entryPoint, targetReq);
maybeDumpIntermediate(compileRequest, code.Buffer(), code.Count(), target);
result = CompileResult(code);
}
@@ -545,29 +550,11 @@ namespace Slang
case CodeGenTarget::SPIRVAssembly:
{
- String code = emitSPIRVAssemblyForEntryPoint(entryPoint);
- maybeDumpIntermediate(compileRequest, code.Buffer(), target);
- result = CompileResult(code);
- }
- break;
-
- case CodeGenTarget::SlangIR:
- {
- List<uint8_t> code = emitSlangIRForEntryPoint(entryPoint);
- maybeDumpIntermediate(compileRequest, code.Buffer(), code.Count(), target);
- result = CompileResult(code);
- }
- break;
-
-#if 0
- case CodeGenTarget::SlangIRAssembly:
- {
- String code = emitSlangIRAssemblyForEntryPoint(entryPoint);
+ String code = emitSPIRVAssemblyForEntryPoint(entryPoint, targetReq);
maybeDumpIntermediate(compileRequest, code.Buffer(), target);
result = CompileResult(code);
}
break;
-#endif
case CodeGenTarget::None:
// The user requested no output
@@ -632,11 +619,13 @@ namespace Slang
}
static void writeEntryPointResultToFile(
- EntryPointRequest* entryPoint)
+ EntryPointRequest* entryPoint,
+ TargetRequest* targetReq,
+ UInt entryPointIndex)
{
auto compileRequest = entryPoint->compileRequest;
auto outputPath = entryPoint->outputPath;
- auto result = entryPoint->result;
+ auto result = targetReq->entryPointResults[entryPointIndex];
switch (result.format)
{
case ResultFormat::Text:
@@ -680,10 +669,12 @@ namespace Slang
}
static void writeEntryPointResultToStandardOutput(
- EntryPointRequest* entryPoint)
+ EntryPointRequest* entryPoint,
+ TargetRequest* targetReq,
+ UInt entryPointIndex)
{
auto compileRequest = entryPoint->compileRequest;
- auto result = entryPoint->result;
+ auto& result = targetReq->entryPointResults[entryPointIndex];
switch (result.format)
{
@@ -699,7 +690,7 @@ namespace Slang
{
// Writing to console, so we need to generate text output.
- switch (compileRequest->Target)
+ switch (targetReq->target)
{
#if SLANG_ENABLE_DXBC_SUPPORT
case CodeGenTarget::DXBytecode:
@@ -750,88 +741,97 @@ namespace Slang
}
static void writeEntryPointResult(
- EntryPointRequest* entryPoint)
+ EntryPointRequest* entryPoint,
+ TargetRequest* targetReq,
+ UInt entryPointIndex)
{
+ auto& result = targetReq->entryPointResults[entryPointIndex];
+
// Skip the case with no output
- if (entryPoint->result.format == ResultFormat::None)
+ if (result.format == ResultFormat::None)
return;
if (entryPoint->outputPath.Length())
{
- writeEntryPointResultToFile(entryPoint);
+ writeEntryPointResultToFile(entryPoint, targetReq, entryPointIndex);
}
else
{
- writeEntryPointResultToStandardOutput(entryPoint);
+ writeEntryPointResultToStandardOutput(entryPoint, targetReq, entryPointIndex);
}
}
- CompileResult emitTranslationUnitEntryPoints(
- TranslationUnitRequest* translationUnit)
+ void emitEntryPoints(
+ TargetRequest* targetReq)
{
- CompileResult result;
-
- for (auto& entryPoint : translationUnit->entryPoints)
- {
- CompileResult entryPointResult = emitEntryPoint(entryPoint.Ptr());
-
- entryPoint->result = entryPointResult;
- }
+ CompileRequest* compileReq = targetReq->compileRequest;
- // The result for the translation unit will just be the concatenation
- // of the results for each entry point. This doesn't actually make
- // much sense, but it is good enough for now.
- //
- // TODO: Replace this with a packaged JSON and/or binary format.
- for (auto& entryPoint : translationUnit->entryPoints)
- {
- result.append(entryPoint->result);
- }
-
- return result;
}
- // Do emit logic for an entire translation unit, which might
- // have zero or more entry points
- CompileResult emitTranslationUnit(
- TranslationUnitRequest* translationUnit)
+ void generateOutputForTarget(
+ TargetRequest* targetReq)
{
- return emitTranslationUnitEntryPoints(translationUnit);
- }
+ CompileRequest* compileReq = targetReq->compileRequest;
-#if 0
- TranslationUnitResult generateOutput(ExtraContext& context)
- {
- TranslationUnitResult result = emitTranslationUnit(context);
- return result;
+ // Generate target code any entry points that
+ // have been requested for compilation.
+ for (auto& entryPoint : compileReq->entryPoints)
+ {
+ CompileResult entryPointResult = emitEntryPoint(entryPoint, targetReq);
+ targetReq->entryPointResults.Add(entryPointResult);
+ }
}
-#endif
void generateOutput(
CompileRequest* compileRequest)
{
- // Start of with per-translation-unit and per-entry-point lowering
- for( auto translationUnit : compileRequest->translationUnits )
+ // Go through the code-generation targets that the user
+ // has specified, and generate code for each of them.
+ //
+ for (auto targetReq : compileRequest->targets)
{
- CompileResult translationUnitResult = emitTranslationUnit(translationUnit.Ptr());
- translationUnit->result = translationUnitResult;
+ generateOutputForTarget(targetReq);
+ }
+
+ // If we are being asked to generate code in a container
+ // format, then we are now in a position to do so.
+ switch (compileRequest->containerFormat)
+ {
+ default:
+ break;
+
+ case ContainerFormat::SlangModule:
+ generateBytecodeForCompileRequest(compileRequest);
+ break;
}
// If we are in command-line mode, we might be expected to actually
// write output to one or more files here.
- // But don't write any output if we were told to skip it.
- if (compileRequest->shouldSkipCodegen)
- return;
-
if (compileRequest->isCommandLineCompile)
{
- for( auto entryPoint : compileRequest->entryPoints )
+ for (auto targetReq : compileRequest->targets)
{
- writeEntryPointResult(entryPoint);
+ UInt entryPointCount = compileRequest->entryPoints.Count();
+ for (UInt ee = 0; ee < entryPointCount; ++ee)
+ {
+ writeEntryPointResult(
+ compileRequest->entryPoints[ee],
+ targetReq,
+ ee);
+ }
}
- }
+ if (compileRequest->containerOutputPath.Length() != 0)
+ {
+ auto& data = compileRequest->generatedBytecode;
+ writeOutputFile(compileRequest,
+ compileRequest->containerOutputPath,
+ data.begin(),
+ data.end() - data.begin(),
+ OutputFileKind::Binary);
+ }
+ }
}
// Debug logic for dumping intermediate outputs
@@ -943,15 +943,6 @@ namespace Slang
}
break;
#endif
-
- case CodeGenTarget::SlangIR:
- dumpIntermediateBinary(compileRequest, data, size, ".slang-ir");
- {
- // TODO: need to support dissassembly from Slang IR binary
-// String slangIRAssembly = dissassembleSlangIR(compileRequest, data, size);
-// dumpIntermediateText(compileRequest, slangIRAssembly.begin(), slangIRAssembly.Length(), ".slang-ir.asm");
- }
- break;
}
}
diff --git a/source/slang/compiler.h b/source/slang/compiler.h
index 9de74e2cd..b7ab980fc 100644
--- a/source/slang/compiler.h
+++ b/source/slang/compiler.h
@@ -47,7 +47,12 @@ namespace Slang
SPIRVAssembly = SLANG_SPIRV_ASM,
DXBytecode = SLANG_DXBC,
DXBytecodeAssembly = SLANG_DXBC_ASM,
- SlangIR = SLANG_IR,
+ };
+
+ enum class ContainerFormat
+ {
+ None = SLANG_CONTAINER_FORMAT_NONE,
+ SlangModule = SLANG_CONTAINER_FORMAT_SLANG_MODULE,
};
enum class LineDirectiveMode : SlangLineDirectiveMode
@@ -108,13 +113,14 @@ namespace Slang
// (only used when compiling from the command line)
String outputPath;
- // The resulting output for the enry point
- //
- // TODO: low-level code generation should be a distinct step
- CompileResult result;
-
// The translation unit that this entry point came from
TranslationUnitRequest* getTranslationUnit();
+
+ // The declaration of the entry-point function itself.
+ // This will be filled in as part of semantic analysis;
+ // it should not be assumed to be available in cases
+ // where any errors were diagnosed.
+ RefPtr<FuncDecl> decl;
};
enum class PassThroughMode : SlangPassThrough
@@ -156,10 +162,27 @@ namespace Slang
// The parsed syntax for the translation unit
RefPtr<ModuleDecl> SyntaxNode;
- // The resulting output for the translation unit
- //
- // TODO: low-level code generation should be a distinct step
- CompileResult result;
+ // The IR-level code for this translation unit.
+ // This will only be valid/non-null after semantic
+ // checking and IR generation are complete, so it
+ // is not safe to use this field without testing for NULL.
+ IRModule* irModule;
+ };
+
+ // A request to generate output in some target format
+ class TargetRequest : public RefObject
+ {
+ public:
+ CompileRequest* compileRequest;
+ CodeGenTarget target;
+
+ // The resulting reflection layout information
+ RefPtr<ProgramLayout> layout;
+
+ // Generated compile results for each entry point
+ // in the parent compile request (indexing matches
+ // the order they are given in the compile request)
+ List<CompileResult> entryPointResults;
};
// A directory to be searched when looking for files (e.g., `#include`)
@@ -182,8 +205,15 @@ namespace Slang
// Pointer to parent session
Session* mSession;
- // What target language are we compiling to?
- CodeGenTarget Target = CodeGenTarget::Unknown;
+ // Information on the targets we are being asked to
+ // generate code for.
+ List<RefPtr<TargetRequest>> targets;
+
+ // What container format are we being asked to generate?
+ ContainerFormat containerFormat = ContainerFormat::None;
+
+ // Path to output container to
+ String containerOutputPath;
// Directories to search for `#include` files or `import`ed modules
List<SearchDirectory> searchDirectories;
@@ -235,18 +265,18 @@ namespace Slang
// Files that compilation depended on
List<String> mDependencyFilePaths;
- // The resulting reflection layout information
- RefPtr<ProgramLayout> layout;
+ // Generated bytecode representation of all the code
+ List<uint8_t> generatedBytecode;
// Modules that have been dynamically loaded via `import`
//
// This is a list of unique modules loaded, in the order they were encountered.
List<RefPtr<ModuleDecl> > loadedModulesList;
- // Map from the logical name of a module to its definition
+ // Map from the path of a module file to its definition
Dictionary<String, RefPtr<ModuleDecl>> mapPathToLoadedModule;
- // Map from the path of a module file to its definition
+ // Map from the logical name of a module to its definition
Dictionary<Name*, RefPtr<ModuleDecl>> mapNameToLoadedModules;
@@ -257,8 +287,12 @@ namespace Slang
void parseTranslationUnit(
TranslationUnitRequest* translationUnit);
+ // Perform primary semantic checking on all
+ // of the translation units in the program
void checkAllTranslationUnits();
+ void generateIR();
+
int executeActionsInner();
int executeActions();
@@ -282,6 +316,9 @@ namespace Slang
String const& name,
Profile profile);
+ UInt addTarget(
+ CodeGenTarget target);
+
RefPtr<ModuleDecl> loadModule(
Name* name,
String const& path,
diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h
index a748b39bd..492c0e084 100644
--- a/source/slang/diagnostic-defs.h
+++ b/source/slang/diagnostic-defs.h
@@ -65,6 +65,9 @@ DIAGNOSTIC( 6, Error, outputPathsImplyDifferentFormats,
DIAGNOSTIC( 6, Error, cannotDeduceOutputFormatFromPath,
"cannot deduce an output format from the output path '$0'")
+DIAGNOSTIC( 6, Error, explicitOutputPathsAndMultipleTargets,
+ "canot use both explicit output paths ('-o') and multiple targets ('-target')")
+
//
// 1xxxx - Lexical anaylsis
//
@@ -308,6 +311,11 @@ DIAGNOSTIC(39999, Error, invalidFloatingPOintLiteralSuffix, "invalid suffix '$0'
DIAGNOSTIC(39999, Error, conflictingExplicitBindingsForParameter, "conflicting explicit bindings for parameter '$0'")
DIAGNOSTIC(39999, Warning, parameterBindingsOverlap, "explicit binding for parameter '$0' overlaps with parameter '$1'")
+DIAGNOSTIC(38000, Error, entryPointFunctionNotFound, "no function found matching entry point name '$0'")
+DIAGNOSTIC(38001, Error, ambiguousEntryPoint, "more than one function matches entry point name '$0'")
+DIAGNOSTIC(38002, Note, entryPointCandidate, "see candidate declaration for entry point '$0'")
+DIAGNOSTIC(38003, Error, entryPointSymbolNotAFunction, "entry point '$0' must be declared as a function")
+
//
// 4xxxx - IL code generation.
//
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp
index 8dcbcdbf4..f9c0b5f09 100644
--- a/source/slang/emit.cpp
+++ b/source/slang/emit.cpp
@@ -4069,6 +4069,17 @@ emitDeclImpl(decl, nullptr);
return getIRName(decl);
}
+ switch (inst->op)
+ {
+ case kIROp_global_var:
+ case kIROp_Func:
+ return ((IRGlobalValue*)inst)->mangledName;
+ break;
+
+ default:
+ break;
+ }
+
StringBuilder sb;
sb << "_S";
sb << getID(inst);
@@ -5285,7 +5296,7 @@ emitDeclImpl(decl, nullptr);
String getIRFuncName(
IRFunc* func)
{
- if (isEntryPoint(func))
+ if (auto entryPointLayout = asEntryPoint(func))
{
// GLSL will always need to use `main` as the
// name for an entry-point function, but other
@@ -5297,10 +5308,7 @@ emitDeclImpl(decl, nullptr);
//
if (getTarget(context) != CodeGenTarget::GLSL)
{
- if (auto highLevelDeclDecoration = func->findDecoration<IRHighLevelDeclDecoration>())
- {
- return getText(highLevelDeclDecoration->decl->getName());
- }
+ return getText(entryPointLayout->entryPoint->getName());
}
//
@@ -5322,15 +5330,41 @@ emitDeclImpl(decl, nullptr);
// Deal with decorations that need
// to be emitted as attributes
- if (auto threadGroupSizeDecoration = func->findDecoration<IRComputeThreadGroupSizeDecoration>())
+ auto entryPointLayout = asEntryPoint(func);
+ if (entryPointLayout)
{
- emit("[numthreads(");
- for (int ii = 0; ii < 3; ++ii)
+ auto profile = entryPointLayout->profile;
+ auto stage = profile.GetStage();
+
+ switch (stage)
{
- if (ii != 0) emit(", ");
- Emit(threadGroupSizeDecoration->sizeAlongAxis[ii]);
+ case Stage::Compute:
+ {
+ static const UInt kAxisCount = 3;
+ UInt sizeAlongAxis[kAxisCount];
+
+ // TODO: this is kind of gross because we are using a public
+ // reflection API function, rather than some kind of internal
+ // utility it forwards to...
+ spReflectionEntryPoint_getComputeThreadGroupSize(
+ (SlangReflectionEntryPoint*)entryPointLayout,
+ kAxisCount,
+ &sizeAlongAxis[0]);
+
+ emit("[numthreads(");
+ for (int ii = 0; ii < 3; ++ii)
+ {
+ if (ii != 0) emit(", ");
+ Emit(sizeAlongAxis[ii]);
+ }
+ emit(")]\n");
+ }
+ break;
+
+ // TODO: There are other stages that will need this kind of handling.
+ default:
+ break;
}
- emit(")]\n");
}
auto name = getIRFuncName(func);
@@ -5509,12 +5543,17 @@ emitDeclImpl(decl, nullptr);
}
#endif
- bool isEntryPoint(IRFunc* func)
+ EntryPointLayout* asEntryPoint(IRFunc* func)
{
- if(func->findDecoration<IREntryPointDecoration>())
- return true;
+ if (auto layoutDecoration = func->findDecoration<IRLayoutDecoration>())
+ {
+ if (auto entryPointLayout = dynamic_cast<EntryPointLayout*>(layoutDecoration->layout))
+ {
+ return entryPointLayout;
+ }
+ }
- return false;
+ return nullptr;
}
// Detect if the given IR function represents a
@@ -6174,36 +6213,21 @@ EntryPointLayout* findEntryPointLayout(
return nullptr;
}
-String emitEntryPoint(
- EntryPointRequest* entryPoint,
- ProgramLayout* programLayout,
- CodeGenTarget target)
+// Given a layout computed for a whole program, find
+// the corresponding layout to use when looking up
+// variables at the global scope.
+//
+// It might be that the global scope was logically
+// mapped to a constant buffer, so that we need
+// to "unwrap" that declaration to get at the
+// actual struct type inside.
+StructTypeLayout* getGlobalStructLayout(
+ ProgramLayout* programLayout)
{
- auto translationUnit = entryPoint->getTranslationUnit();
- auto session = entryPoint->compileRequest->mSession;
-
- SharedEmitContext sharedContext;
- sharedContext.target = target;
- sharedContext.finalTarget = entryPoint->compileRequest->Target;
- sharedContext.entryPoint = entryPoint;
-
- if (entryPoint)
- {
- sharedContext.entryPointLayout = findEntryPointLayout(
- programLayout,
- entryPoint);
- }
-
- sharedContext.programLayout = programLayout;
-
- // Layout information for the global scope is either an ordinary
- // `struct` in the common case, or a constant buffer in the case
- // where there were global-scope uniforms.
auto globalScopeLayout = programLayout->globalScopeLayout;
- StructTypeLayout* globalStructLayout = nullptr;
if( auto gs = globalScopeLayout.As<StructTypeLayout>() )
{
- globalStructLayout = gs.Ptr();
+ return gs.Ptr();
}
else if( auto globalConstantBufferLayout = globalScopeLayout.As<ParameterBlockTypeLayout>() )
{
@@ -6229,12 +6253,42 @@ String emitEntryPoint(
// We expect all constant buffers to contain `struct` types for now
SLANG_RELEASE_ASSERT(elementTypeStructLayout);
- globalStructLayout = elementTypeStructLayout.Ptr();
+ return elementTypeStructLayout.Ptr();
}
else
{
SLANG_UNEXPECTED("uhandled global-scope binding layout");
+ return nullptr;
+ }
+}
+
+String emitEntryPoint(
+ EntryPointRequest* entryPoint,
+ ProgramLayout* programLayout,
+ CodeGenTarget target,
+ CodeGenTarget finalTarget)
+{
+ auto translationUnit = entryPoint->getTranslationUnit();
+ auto session = entryPoint->compileRequest->mSession;
+
+ SharedEmitContext sharedContext;
+ sharedContext.target = target;
+ sharedContext.finalTarget = finalTarget;
+ sharedContext.entryPoint = entryPoint;
+
+ if (entryPoint)
+ {
+ sharedContext.entryPointLayout = findEntryPointLayout(
+ programLayout,
+ entryPoint);
}
+
+ sharedContext.programLayout = programLayout;
+
+ // Layout information for the global scope is either an ordinary
+ // `struct` in the common case, or a constant buffer in the case
+ // where there were global-scope uniforms.
+ StructTypeLayout* globalStructLayout = getGlobalStructLayout(programLayout);
sharedContext.globalStructLayout = globalStructLayout;
auto translationUnitSyntax = translationUnit->SyntaxNode.Ptr();
@@ -6268,11 +6322,14 @@ String emitEntryPoint(
// This seems to be case (3), because the user is asking for full
// checking, and so we can assume we understand the code fully.
//
- // In this case we want to translate to our intermediate representation
- // and do optimizations/transformations there before we emit final code.
+ // The IR code for the module should already have been generated,
+ // so that we "just" need to specialize it as needed for the
+ // specific target and entry point in use.
//
-
- auto lowered = lowerEntryPointToIR(entryPoint, programLayout, target);
+ auto lowered = specializeIRForEntryPoint(
+ entryPoint,
+ programLayout,
+ target);
// debugging:
if (translationUnit->compileRequest->shouldDumpIR)
@@ -6280,18 +6337,12 @@ String emitEntryPoint(
dumpIR(lowered);
}
- // TODO: depending on the target we are trying to generate code for,
- // we may need to apply certain transformations, and we may also
- // need to link in (and then inline) target-specific implementations
- // for the library functions that the user called.
-
- switch(target)
- {
- case CodeGenTarget::GLSL:
- legalizeEntryPointsForGLSL(session, lowered);
- break;
- }
-
+ // TODO: we should apply some guaranteed transformations here,
+ // to eliminate constructs that aren't legal downstream (e.g. generics).
+ //
+ // TODO: Need to decide whether to do these before or after
+ // target-specific legalization steps. Currently I've folded
+ // legalization into the specialization above.
// TODO: do we want to emit directly from IR, or translate the
// IR back into AST for emission?
diff --git a/source/slang/emit.h b/source/slang/emit.h
index 5d546bdf4..e17a84d5a 100644
--- a/source/slang/emit.h
+++ b/source/slang/emit.h
@@ -22,6 +22,12 @@ namespace Slang
String emitEntryPoint(
EntryPointRequest* entryPoint,
ProgramLayout* programLayout,
- CodeGenTarget target);
+
+ // The target language to generate code in (e.g., HLSL/GLSL)
+ CodeGenTarget target,
+
+ // The "final" target that we are being asked to compile for
+ // (e.g., SPIR-V, DXBC, ...).
+ CodeGenTarget finalTarget);
}
#endif
diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h
index 64c200b12..9ac79413f 100644
--- a/source/slang/ir-insts.h
+++ b/source/slang/ir-insts.h
@@ -35,24 +35,6 @@ struct IRLayoutDecoration : IRDecoration
Layout* layout;
};
-// Identifies a function as an entry point for some stage
-struct IREntryPointDecoration : IRDecoration
-{
- enum { kDecorationOp = kIRDecorationOp_EntryPoint };
-
- Profile profile;
- EntryPointLayout* layout;
-};
-
-// Associates a compute-shader entry point function
-// with a thread-group size.
-struct IRComputeThreadGroupSizeDecoration : IRDecoration
-{
- enum { kDecorationOp = kIRDecorationOp_ComputeThreadGroupSize };
-
- UInt sizeAlongAxis[3];
-};
-
enum IRLoopControl
{
kIRLoopControl_Unroll,
@@ -468,6 +450,15 @@ struct IRBuilder
IRLayoutDecoration* addLayoutDecoration(IRValue* value, Layout* layout);
};
+// Generate a clone of an IR module that is specialized for
+// a particular entry point, target, etc.
+
+IRModule* specializeIRForEntryPoint(
+ EntryPointRequest* entryPointRequest,
+ ProgramLayout* programLayout,
+ CodeGenTarget target);
+
+
}
#endif
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp
index 79d9883ea..31a35cd08 100644
--- a/source/slang/ir.cpp
+++ b/source/slang/ir.cpp
@@ -3,6 +3,7 @@
#include "ir-insts.h"
#include "../core/basic.h"
+#include "mangle.h"
namespace Slang
{
@@ -244,40 +245,25 @@ namespace Slang
for( UInt aa = 0; aa < fixedArgCount; ++aa )
{
- operand->init(inst, fixedArgs[aa]);
+ if (fixedArgs)
+ {
+ operand->init(inst, fixedArgs[aa]);
+ }
operand++;
}
for( UInt aa = 0; aa < varArgCount; ++aa )
{
- operand->init(inst, varArgs[aa]);
+ if (varArgs)
+ {
+ operand->init(inst, varArgs[aa]);
+ }
operand++;
}
return inst;
}
- // Create an IR instruction/value and initialize it.
- //
- // For this overload, the type of the instruction is
- // folded into the argument list (so `args[0]` needs
- // to be the type of the instruction)
- static IRValue* createInstImpl(
- IRBuilder* builder,
- UInt size,
- IROp op,
- UInt argCount,
- IRValue* const* args)
- {
- return createInstImpl(
- builder,
- size,
- op,
- (IRType*) args[0],
- argCount - 1,
- args + 1);
- }
-
template<typename T>
static T* createInst(
IRBuilder* builder,
@@ -1997,12 +1983,10 @@ namespace Slang
void legalizeEntryPointForGLSL(
Session* session,
IRFunc* func,
- IREntryPointDecoration* entryPointInfo)
+ EntryPointLayout* entryPointLayout)
{
auto module = func->parentModule;
- auto entryPointLayout = entryPointInfo->layout;
-
// We require that the entry-point function has no uses,
// because otherwise we'd invalidate the signature
// at all existing call sites.
@@ -2235,28 +2219,447 @@ namespace Slang
// the way that things have been moved around.
}
- void legalizeEntryPointsForGLSL(
- Session* session,
- IRModule* module)
+ // Needed for lookup up entry-point layouts.
+ //
+ // TODO: maybe arrange so that codegen is driven from the layout layer
+ // instead of the input/request layer.
+ EntryPointLayout* findEntryPointLayout(
+ ProgramLayout* programLayout,
+ EntryPointRequest* entryPointRequest);
+
+ struct IRSpecSymbol : RefObject
+ {
+ IRGlobalValue* irGlobalValue;
+ RefPtr<IRSpecSymbol> nextWithSameName;
+ };
+
+ struct IRSpecContext
+ {
+ // The specialized module we are building
+ IRModule* module;
+
+ // The original, unspecialized module we are copying
+ IRModule* originalModule;
+
+ // The IR builder to use for creating nodes
+ IRBuilder* builder;
+
+ // A map from mangled symbol names to zero or
+ // more global IR values that have that name,
+ // in the *original* module.
+ Dictionary<String, RefPtr<IRSpecSymbol>> symbols;
+
+ // A map from the mangled name of a global variable
+ // to the layout to use for it.
+ Dictionary<String, VarLayout*> globalVarLayouts;
+
+ // A map from values in the original IR module
+ // to their equivalent in the cloned module.
+ Dictionary<IRValue*, IRValue*> clonedValues;
+ };
+
+ void registerClonedValue(
+ IRSpecContext* context,
+ IRValue* clonedValue,
+ IRValue* originalValue)
+ {
+ context->clonedValues.Add(originalValue, clonedValue);
+ }
+
+ void cloneDecorations(
+ IRSpecContext* context,
+ IRValue* clonedValue,
+ IRValue* originalValue)
{
- // We need to walk through all the global entry point
- // declarations, and transform them to comply with
- // GLSL rules.
- for( auto globalValue = module->getFirstGlobalValue(); globalValue; globalValue = globalValue->getNextValue())
+ for (auto dd = originalValue->firstDecoration; dd; dd = dd->next)
{
- // Is the global value a function?
- if(globalValue->op != kIROp_Func)
- continue;
- IRFunc* func = (IRFunc*) globalValue;
+ switch (dd->op)
+ {
+ case kIRDecorationOp_HighLevelDecl:
+ {
+ auto originalDecoration = (IRHighLevelDeclDecoration*)dd;
- // Is the function an entry point?
- IREntryPointDecoration* entryPointDecoration = func->findDecoration<IREntryPointDecoration>();
- if(!entryPointDecoration)
+ context->builder->addHighLevelDeclDecoration(clonedValue, originalDecoration->decl);
+ }
+ break;
+
+ default:
+ // Don't clone any decorations we don't understand.
+ break;
+ }
+ }
+
+ // TODO: implement this
+ }
+
+ IRGlobalVar* cloneGlobalVar(IRSpecContext* context, IRGlobalVar* originalVar);
+ IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc);
+
+ IRValue* cloneValue(
+ IRSpecContext* context,
+ IRValue* originalValue)
+ {
+ IRValue* clonedValue = nullptr;
+ if (context->clonedValues.TryGetValue(originalValue, clonedValue))
+ return clonedValue;
+
+ switch (originalValue->op)
+ {
+ case kIROp_global_var:
+ return cloneGlobalVar(context, (IRGlobalVar*)originalValue);
+ break;
+
+ case kIROp_Func:
+ return cloneFunc(context, (IRFunc*)originalValue);
+ break;
+
+ case kIROp_boolConst:
+ {
+ IRConstant* c = (IRConstant*)originalValue;
+ return context->builder->getBoolValue(c->u.intVal != 0);
+ }
+ break;
+
+
+ case kIROp_IntLit:
+ {
+ IRConstant* c = (IRConstant*)originalValue;
+ return context->builder->getIntValue(c->type, c->u.intVal);
+ }
+ break;
+
+ case kIROp_FloatLit:
+ {
+ IRConstant* c = (IRConstant*)originalValue;
+ return context->builder->getFloatValue(c->type, c->u.floatVal);
+ }
+ break;
+
+ case kIROp_decl_ref:
+ {
+ IRDeclRef* od = (IRDeclRef*)originalValue;
+ return context->builder->getDeclRefVal(od->declRef);
+ }
+ break;
+
+ default:
+ SLANG_UNEXPECTED("no value registered for IR value");
+ return nullptr;
+ }
+ }
+
+ void cloneInst(
+ IRSpecContext* context,
+ IRBuilder* builder,
+ IRInst* originalInst)
+ {
+ switch (originalInst->op)
+ {
+ // TODO: are there any instruction types that need to be handled
+ // specially here? That would be anything that has more state
+ // than is visible in its operand list...
+
+ default:
+ {
+ // The common case is that we just need to construct a cloned
+ // instruction with the right number of operands, intialize
+ // it, and then add it to the sequence.
+ UInt argCount = originalInst->getArgCount();
+ IRInst* clonedInst = createInstWithTrailingArgs<IRInst>(
+ builder, originalInst->op, originalInst->type,
+ 0, nullptr,
+ argCount, nullptr);
+ builder->addInst(clonedInst);
+ registerClonedValue(context, clonedInst, originalInst);
+
+ cloneDecorations(context, clonedInst, originalInst);
+
+ for (UInt aa = 0; aa < argCount; ++aa)
+ {
+ IRValue* originalArg = originalInst->getArg(aa);
+ IRValue* clonedArg = cloneValue(context, originalArg);
+
+ clonedInst->getArgs()[aa].init(clonedInst, clonedArg);
+ }
+ }
+
+ break;
+ }
+ }
+
+ IRGlobalVar* cloneGlobalVar(IRSpecContext* context, IRGlobalVar* originalVar)
+ {
+ auto clonedVar = context->builder->createGlobalVar(originalVar->getType()->getValueType());
+ registerClonedValue(context, clonedVar, originalVar);
+
+ auto mangledName = originalVar->mangledName;
+ clonedVar->mangledName = mangledName;
+
+ cloneDecorations(context, clonedVar, originalVar);
+
+ VarLayout* layout = nullptr;
+ if (context->globalVarLayouts.TryGetValue(mangledName, layout))
+ {
+ context->builder->addLayoutDecoration(clonedVar, layout);
+ }
+
+ // TODO: once we support initializers on global variables,
+ // we'll need to handle cloning it here.
+
+ return clonedVar;
+ }
+
+ void cloneFunctionCommon(
+ IRSpecContext* context,
+ IRFunc* clonedFunc,
+ IRFunc* originalFunc)
+ {
+ // First clone all the simple properties.
+ clonedFunc->mangledName = originalFunc->mangledName;
+ clonedFunc->genericParams = originalFunc->genericParams;
+ clonedFunc->type = originalFunc->type;
+
+ cloneDecorations(context, clonedFunc, originalFunc);
+
+ // Next we are going to clone the actual code.
+ IRBuilder builderStorage = *context->builder;
+ IRBuilder* builder = &builderStorage;
+ builder->func = clonedFunc;
+
+ // We will walk through the blocks of the function, and clone each of them.
+ //
+ // We need to create the cloned blocks first, and then walk through them,
+ // because blocks might be forward referenced (this is not possible
+ // for other cases of instructions).
+ for (auto originalBlock = originalFunc->getFirstBlock();
+ originalBlock;
+ originalBlock = originalBlock->getNextBlock())
+ {
+ IRBlock* clonedBlock = builder->createBlock();
+ clonedFunc->addBlock(clonedBlock);
+ registerClonedValue(context, clonedBlock, originalBlock);
+
+ // We can go ahead and clone parameters here, while we are at it.
+ builder->block = clonedBlock;
+ for (auto originalParam = originalBlock->getFirstParam();
+ originalParam;
+ originalParam = originalParam->getNextParam())
+ {
+ IRParam* clonedParam = builder->emitParam(originalParam->getType());
+ registerClonedValue(context, clonedParam, originalParam);
+ }
+ }
+
+ // Okay, now we are in a good position to start cloning
+ // the instructions inside the blocks.
+ {
+ IRBlock* ob = originalFunc->getFirstBlock();
+ IRBlock* cb = clonedFunc->getFirstBlock();
+ while (ob)
+ {
+ assert(cb);
+
+ builder->block = cb;
+ for (auto oi = ob->getFirstInst(); oi; oi = oi->nextInst)
+ {
+ cloneInst(context, builder, oi);
+ }
+
+ ob = ob->getNextBlock();
+ cb = cb->getNextBlock();
+ }
+ }
+
+ // Shuffle the function to the end of the list, because
+ // it needs to follow its dependencies.
+ //
+ // TODO: This isn't really a good requirement to place on the IR...
+ clonedFunc->removeFromParent();
+ clonedFunc->insertAtEnd(context->module);
+ }
+
+ IRFunc* specializeIRForEntryPoint(
+ IRSpecContext* context,
+ EntryPointRequest* entryPointRequest,
+ EntryPointLayout* entryPointLayout)
+ {
+ // Look up the IR symbol by name
+ String mangledName = getMangledName(entryPointRequest->decl);
+ RefPtr<IRSpecSymbol> sym;
+ if (!context->symbols.TryGetValue(mangledName, sym))
+ {
+ SLANG_UNEXPECTED("no matching IR symbol");
+ return nullptr;
+ }
+
+ // TODO: deal with the case where we might
+ // have multiple versions...
+
+ auto globalValue = sym->irGlobalValue;
+ if (globalValue->op != kIROp_Func)
+ {
+ SLANG_UNEXPECTED("expected an IR function");
+ return nullptr;
+ }
+ auto originalFunc = (IRFunc*)globalValue;
+
+ // Create a clone for the IR function
+ auto clonedFunc = context->builder->createFunc();
+
+ // Note: we do *not* register this cloned declaration
+ // as the cloned value for the original symbol.
+ // This is kind of a kludge, but it ensures that
+ // in the unlikely case that the function is both
+ // used as an entry point and a callable function
+ // (yes, this would imply recursion...) we actually
+ // have two copies, which lets us arbitrarily
+ // transform the entry point to meet target requirements.
+ //
+ // TODO: The above statement is kind of bunk, though,
+ // because both versions of the function would have
+ // the same mangled name... :(
+
+ // We need to clone all the properties of the original
+ // function, including any blocks, their parameters,
+ // and their instructions.
+ cloneFunctionCommon(context, clonedFunc, originalFunc);
+
+ // We need to attach the layout information for
+ // the entry point to this declaration, so that
+ // we can use it to inform downstream code emit.
+ context->builder->addLayoutDecoration(
+ clonedFunc,
+ entryPointLayout);
+
+ return clonedFunc;
+ }
+
+ // The case for functions that are not the entry point is
+ // strictly simpler, so that is nice.
+ IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc)
+ {
+ // TODO: We really need to scan through all the various
+ // global function symbols that have the same mangled name,
+ // and pick the correct one to lower for the target.
+
+ auto clonedFunc = context->builder->createFunc();
+ registerClonedValue(context, clonedFunc, originalFunc);
+ cloneFunctionCommon(context, clonedFunc, originalFunc);
+ return clonedFunc;
+ }
+
+ StructTypeLayout* getGlobalStructLayout(
+ ProgramLayout* programLayout);
+
+ IRModule* specializeIRForEntryPoint(
+ EntryPointRequest* entryPointRequest,
+ ProgramLayout* programLayout,
+ CodeGenTarget target)
+ {
+ auto compileRequest = entryPointRequest->compileRequest;
+ auto session = compileRequest->mSession;
+ auto translationUnit = entryPointRequest->getTranslationUnit();
+ auto originalIRModule = translationUnit->irModule;
+ if (!originalIRModule)
+ {
+ // We should already have emitted IR for the original
+ // translation unit, and it we don't have it, then
+ // we are now in trouble.
+ return nullptr;
+ }
+
+ auto entryPointLayout = findEntryPointLayout(programLayout, entryPointRequest);
+
+ // We now need to start cloning IR symbols from `originalIRModule`
+ // into a fresh IR module for this entry point. Along the way we need to:
+ //
+ // 1. Attach layout information from `programLayout` and/or `entryPointLayout`
+ // onto the cloned IR symbols, to drive later code generation.
+ //
+ // 2. In cases where a function might have multiple target-specific definitions,
+ // we need to pick the "best" one for the chosen code generation target.
+ //
+
+ SharedIRBuilder sharedBuilderStorage;
+ SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
+ sharedBuilder->module = nullptr;
+ sharedBuilder->session = compileRequest->mSession;
+
+ IRBuilder builderStorage;
+ IRBuilder* builder = &builderStorage;
+ builder->shared = sharedBuilder;
+
+ IRModule* module = builder->createModule();
+ sharedBuilder->module = module;
+
+ //
+
+ IRSpecContext contextStorage;
+ IRSpecContext* context = &contextStorage;
+
+ context->builder = builder;
+ context->module = module;
+ context->originalModule = originalIRModule;
+
+ // First, we will populate a map with all of the IR values
+ // that use the same mangled name, to make lookup easier
+ // in other steps.
+ for (auto gv = originalIRModule->firstGlobalValue; gv; gv = gv->nextGlobalValue)
+ {
+ String mangledName = gv->mangledName;
+ if (mangledName == "")
continue;
- // Okay, we need to legalize this one.
- legalizeEntryPointForGLSL(session, func, entryPointDecoration);
+ RefPtr<IRSpecSymbol> sym = new IRSpecSymbol();
+ sym->irGlobalValue = gv;
+
+ RefPtr<IRSpecSymbol> prev;
+ if (context->symbols.TryGetValue(mangledName, prev))
+ {
+ sym->nextWithSameName = prev->nextWithSameName;
+ prev->nextWithSameName = sym;
+ }
+ else
+ {
+ context->symbols.Add(mangledName, sym);
+ }
+ }
+
+ // Next, we want to optimize lookup over
+ auto globalStructLayout = getGlobalStructLayout(programLayout);
+ for (auto globalVarLayout : globalStructLayout->fields)
+ {
+ String mangledName = getMangledName(globalVarLayout->varDecl);
+ context->globalVarLayouts.AddIfNotExists(mangledName, globalVarLayout);
}
+
+ // Next, we make sure to clone the global value for
+ // the entry point function itself, and rely on
+ // this step to recursively copy over anything else
+ // it might reference.
+ auto irEntryPoint = specializeIRForEntryPoint(context, entryPointRequest, entryPointLayout);
+
+ // TODO: *technically* we should consider the case where
+ // we have global variables with initializers, since
+ // these should get run whether or not the entry point
+ // references them.
+
+ // Depending on the downstream target, we may need to apply some
+ // guaranteed transformations to legalize things. We will go
+ // ahead and apply there here for now.
+ switch (target)
+ {
+ case CodeGenTarget::GLSL:
+ {
+ legalizeEntryPointForGLSL(session, irEntryPoint, entryPointLayout);
+ }
+ break;
+
+ default:
+ break;
+ }
+
+ return module;
}
diff --git a/source/slang/ir.h b/source/slang/ir.h
index 4fd165c33..ecc77dbc4 100644
--- a/source/slang/ir.h
+++ b/source/slang/ir.h
@@ -97,10 +97,7 @@ enum IRDecorationOp : uint16_t
{
kIRDecorationOp_HighLevelDecl,
kIRDecorationOp_Layout,
- kIRDecorationOp_EntryPoint,
- kIRDecorationOp_ComputeThreadGroupSize,
kIRDecorationOp_LoopControl,
- kIRDecorationOp_MangledName,
};
// A "decoration" that gets applied to an instruction.
@@ -291,6 +288,11 @@ struct IRGlobalValue : IRValue
{
IRModule* parentModule;
+ // The mangled name, for a symbol that should have linkage,
+ // or which might have multiple declarations.
+ String mangledName;
+
+
IRGlobalValue* nextGlobalValue;
IRGlobalValue* prevGlobalValue;
@@ -319,10 +321,6 @@ struct IRFunc : IRGlobalValue
// The type of the IR-level function
IRFuncType* getType() { return (IRFuncType*) type.Ptr(); }
- // The mangled name, for a function
- // that should have linkage.
- String mangledName;
-
// Any generic parameters this function has
List<RefPtr<Decl>> genericParams;
@@ -367,14 +365,6 @@ void printSlangIRAssembly(StringBuilder& builder, IRModule* module);
String getSlangIRAssembly(IRModule* module);
void dumpIR(IRModule* module);
-
-// IR transformations
-
-// Transform shader entry points so that they conform to GLSL rules.
-void legalizeEntryPointsForGLSL(
- Session* session,
- IRModule* module);
-
}
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index 29a1fec2f..d4551421f 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -268,9 +268,7 @@ LoweredValInfo LoweredValInfo::swizzledLValue(
struct SharedIRGenContext
{
- EntryPointRequest* entryPoint;
- ProgramLayout* programLayout;
- CodeGenTarget target;
+ CompileRequest* compileRequest;
Dictionary<DeclRef<Decl>, LoweredValInfo> declValues;
@@ -292,7 +290,7 @@ struct IRGenContext
Session* getSession()
{
- return shared->entryPoint->compileRequest->mSession;
+ return shared->compileRequest->mSession;
}
};
@@ -729,8 +727,7 @@ void lowerStmt(
LoweredValInfo lowerDecl(
IRGenContext* context,
- DeclBase* decl,
- Layout* layout);
+ DeclBase* decl);
IRType* getIntType(
IRGenContext* context)
@@ -907,8 +904,7 @@ struct LoweringVisitor
LoweredValInfo createVar(
IRGenContext* context,
RefPtr<Type> type,
- Decl* decl = nullptr,
- Layout* layout = nullptr)
+ Decl* decl = nullptr)
{
auto builder = context->irBuilder;
auto irAlloc = builder->emitVar(type);
@@ -918,11 +914,6 @@ LoweredValInfo createVar(
builder->addHighLevelDeclDecoration(irAlloc, decl);
}
- if (layout)
- {
- builder->addLayoutDecoration(irAlloc, layout);
- }
-
return LoweredValInfo::ptr(irAlloc);
}
@@ -1817,7 +1808,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// be lifted later (pushing capture analysis
// down to the IR).
//
- lowerDecl(context, stmt->decl, nullptr);
+ lowerDecl(context, stmt->decl);
}
void visitSeqStmt(SeqStmt* stmt)
@@ -1991,18 +1982,12 @@ top:
struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
IRGenContext* context;
- Layout* layout;
IRBuilder* getBuilder()
{
return context->irBuilder;
}
- Layout* getLayout()
- {
- return layout;
- }
-
LoweredValInfo visitDeclBase(DeclBase* decl)
{
SLANG_UNIMPLEMENTED_X("decl catch-all");
@@ -2086,17 +2071,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
auto builder = getBuilder();
auto irGlobal = builder->createGlobalVar(varType);
+ irGlobal->mangledName = getMangledName(decl);
if (decl)
{
builder->addHighLevelDeclDecoration(irGlobal, decl);
}
- if (auto layout = getLayout())
- {
- builder->addLayoutDecoration(irGlobal, layout);
- }
-
// A global variable's SSA value is a *pointer* to
// the underlying storage.
auto globalVal = LoweredValInfo::ptr(irGlobal);
@@ -2168,7 +2149,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
varType = context->getSession()->getGroupSharedType(varType);
}
- LoweredValInfo varVal = createVar(context, varType, decl, getLayout());
+ LoweredValInfo varVal = createVar(context, varType, decl);
if( auto initExpr = decl->initExpr )
{
@@ -2840,11 +2821,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
LoweredValInfo lowerDecl(
IRGenContext* context,
- DeclBase* decl,
- Layout* layout)
+ DeclBase* decl)
{
DeclLoweringVisitor visitor;
- visitor.layout = layout;
visitor.context = context;
return visitor.dispatch(decl);
}
@@ -2870,100 +2849,31 @@ LoweredValInfo ensureDecl(
subContext.irBuilder = &subIRBuilder;
- RefPtr<VarLayout> layout;
- auto globalScopeLayout = shared->programLayout->globalScopeLayout;
- if (auto globalParameterBlockLayout = globalScopeLayout.As<ParameterBlockTypeLayout>())
- {
- globalScopeLayout = globalParameterBlockLayout->elementTypeLayout;
- }
- if (auto globalStructTypeLayout = globalScopeLayout.As<StructTypeLayout>())
- {
- globalStructTypeLayout->mapVarToLayout.TryGetValue(declRef.getDecl(), layout);
- }
-
- result = lowerDecl(&subContext, declRef.getDecl(), layout);
+ result = lowerDecl(&subContext, declRef.getDecl());
shared->declValues[declRef] = result;
return result;
}
-
-EntryPointLayout* findEntryPointLayout(
- SharedIRGenContext* shared,
- EntryPointRequest* entryPointRequest)
-{
- for( auto entryPointLayout : shared->programLayout->entryPoints )
- {
- if(entryPointLayout->entryPoint->getName() != entryPointRequest->name)
- continue;
-
- if(entryPointLayout->profile != entryPointRequest->profile)
- continue;
-
- // TODO: can't easily filter on translation unit here...
- // Ideally the `EntryPointRequest` should get filled in with a pointer
- // the specific function declaration that represents the entry point.
-
- return entryPointLayout.Ptr();
- }
-
- return nullptr;
-}
-
static void lowerEntryPointToIR(
IRGenContext* context,
- EntryPointRequest* entryPointRequest,
- EntryPointLayout* entryPointLayout)
+ EntryPointRequest* entryPointRequest)
{
// First, lower the entry point like an ordinary function
- auto entryPointFuncDecl = entryPointLayout->entryPoint;
- auto loweredEntryPointFunc = lowerDecl(context, entryPointFuncDecl, entryPointLayout);
- auto irFunc = getSimpleVal(context, loweredEntryPointFunc);
-
- auto builder = context->irBuilder;
-
- // We are going to attach all the entry-point-specific information
- // to the declaration as meta-data decorations for now.
- //
- // I'm not convinced this is the right way to go, but it is
- // the easiest and most expedient thing.
- //
- auto profile = entryPointRequest->profile;
- auto stage = profile.GetStage();
-
- auto entryPointDecoration = builder->addDecoration<IREntryPointDecoration>(irFunc);
- entryPointDecoration->profile = profile;
- entryPointDecoration->layout = entryPointLayout;
-
- // Attach layout information here.
- builder->addLayoutDecoration(irFunc, entryPointLayout);
-
- // Next, we need to start attaching the meta-data that is
- // required based on the particular stage we are targetting:
- switch (stage)
+ auto entryPointFuncDecl = entryPointRequest->decl;
+ if (!entryPointFuncDecl)
{
- case Stage::Compute:
- {
- // We need to attach information about the thread group size here.
- auto threadGroupSizeDecoration = builder->addDecoration<IRComputeThreadGroupSizeDecoration>(irFunc);
- static const UInt kAxisCount = 3;
-
- // TODO: this is kind of gross because we are using a public
- // reflection API function, rather than some kind of internal
- // utility it forwards to...
- spReflectionEntryPoint_getComputeThreadGroupSize(
- (SlangReflectionEntryPoint*)entryPointLayout,
- kAxisCount,
- &threadGroupSizeDecoration->sizeAlongAxis[0]);
- }
- break;
-
- default:
- break;
+ // Something must have gone wrong earlier, if we
+ // weren't able to associate a declaration with
+ // the entry point request.
+ return;
}
+
+ auto loweredEntryPointFunc = lowerDecl(context, entryPointFuncDecl);
}
+#if 0
IRModule* lowerEntryPointToIR(
EntryPointRequest* entryPoint,
ProgramLayout* programLayout,
@@ -3002,7 +2912,55 @@ IRModule* lowerEntryPointToIR(
return module;
}
+#endif
+
+IRModule* generateIRForTranslationUnit(
+ TranslationUnitRequest* translationUnit)
+{
+ auto compileRequest = translationUnit->compileRequest;
+
+ SharedIRGenContext sharedContextStorage;
+ SharedIRGenContext* sharedContext = &sharedContextStorage;
+
+ sharedContext->compileRequest = compileRequest;
+
+ IRGenContext contextStorage;
+ IRGenContext* context = &contextStorage;
+
+ context->shared = sharedContext;
+
+ SharedIRBuilder sharedBuilderStorage;
+ SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
+ sharedBuilder->module = nullptr;
+ sharedBuilder->session = compileRequest->mSession;
+
+ IRBuilder builderStorage;
+ IRBuilder* builder = &builderStorage;
+ builder->shared = sharedBuilder;
+
+ IRModule* module = builder->createModule();
+ sharedBuilder->module = module;
+
+ context->irBuilder = builder;
+
+ // We need to emit IR for all public/exported symbols
+ // in the translation unit.
+ for (auto entryPoint : translationUnit->entryPoints)
+ {
+ lowerEntryPointToIR(context, entryPoint);
+ }
+
+ // If we are being sked to dump IR during compilation,
+ // then we can dump the initial IR for the module here.
+ if(compileRequest->shouldDumpIR)
+ {
+ dumpIR(module);
+ }
+
+ return module;
+}
+#if 0
String emitSlangIRAssemblyForEntryPoint(
EntryPointRequest* entryPoint)
{
@@ -3015,6 +2973,7 @@ String emitSlangIRAssemblyForEntryPoint(
return getSlangIRAssembly(irModule);
}
+#endif
} // namespace Slang
diff --git a/source/slang/lower-to-ir.h b/source/slang/lower-to-ir.h
index aa2cef631..bd878d6fa 100644
--- a/source/slang/lower-to-ir.h
+++ b/source/slang/lower-to-ir.h
@@ -13,15 +13,14 @@
namespace Slang
{
+ class CompileRequest;
class EntryPointRequest;
class ProgramLayout;
class TranslationUnitRequest;
struct ExtensionUsageTracker;
- IRModule* lowerEntryPointToIR(
- EntryPointRequest* entryPoint,
- ProgramLayout* programLayout,
- CodeGenTarget target);
+ IRModule* generateIRForTranslationUnit(
+ TranslationUnitRequest* translationUnit);
}
#endif
diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp
index 1387abc23..3e6ee9917 100644
--- a/source/slang/lower.cpp
+++ b/source/slang/lower.cpp
@@ -4556,49 +4556,8 @@ struct LoweringVisitor
};
-static RefPtr<StructTypeLayout> getGlobalStructLayout(
- ProgramLayout* programLayout)
-{
- // Layout information for the global scope is either an ordinary
- // `struct` in the common case, or a constant buffer in the case
- // where there were global-scope uniforms.
- auto globalScopeLayout = programLayout->globalScopeLayout;
- StructTypeLayout* globalStructLayout = globalScopeLayout.As<StructTypeLayout>();
- if(globalStructLayout)
- { }
- else if(auto globalConstantBufferLayout = globalScopeLayout.As<ParameterBlockTypeLayout>())
- {
- // TODO: the `cbuffer` case really needs to be emitted very
- // carefully, but that is beyond the scope of what a simple rewriter
- // can easily do (without semantic analysis, etc.).
- //
- // The crux of the problem is that we need to collect all the
- // global-scope uniforms (but not declarations that don't involve
- // uniform storage...) and put them in a single `cbuffer` declaration,
- // so that we can give it an explicit location. The fields in that
- // declaration might use various type declarations, so we'd really
- // need to emit all the type declarations first, and that involves
- // some large scale reorderings.
- //
- // For now we will punt and just emit the declarations normally,
- // and hope that the global-scope block (`$Globals`) gets auto-assigned
- // the same location that we manually asigned it.
-
- auto elementTypeLayout = globalConstantBufferLayout->elementTypeLayout;
- auto elementTypeStructLayout = elementTypeLayout.As<StructTypeLayout>();
-
- // We expect all constant buffers to contain `struct` types for now
- SLANG_RELEASE_ASSERT(elementTypeStructLayout);
-
- globalStructLayout = elementTypeStructLayout.Ptr();
- }
- else
- {
- SLANG_UNEXPECTED("unhandled type for global-scope parameter layout");
- }
- return globalStructLayout;
-}
-
+StructTypeLayout* getGlobalStructLayout(
+ ProgramLayout* programLayout);
// Determine if the user is just trying to "rewrite" their input file
// into an output file. This will affect the way we approach code
diff --git a/source/slang/mangle.h b/source/slang/mangle.h
index 6eea96a19..286e2c2c3 100644
--- a/source/slang/mangle.h
+++ b/source/slang/mangle.h
@@ -7,7 +7,7 @@
namespace Slang
{
- struct Decl;
+ class Decl;
String getMangledName(Decl* decl);
}
diff --git a/source/slang/options.cpp b/source/slang/options.cpp
index 79832a9c3..971d17b51 100644
--- a/source/slang/options.cpp
+++ b/source/slang/options.cpp
@@ -223,6 +223,11 @@ struct OptionsParser
#undef CASE
+ else if (path.EndsWith(".slang-module"))
+ {
+ spSetOutputContainerFormat(compileRequest, SLANG_CONTAINER_FORMAT_SLANG_MODULE);
+ requestImpl->containerOutputPath = path;
+ }
else
{
// Allow an unknown-format `-o`, assuming we get a target format
@@ -600,6 +605,16 @@ struct OptionsParser
}
}
+ // If the user is requesting multiple targets, *and* is asking
+ // for direct output files for entry points, that is an error.
+ if (rawOutputPaths.Count() != 0 && requestImpl->targets.Count() > 1)
+ {
+ requestImpl->mSink.diagnose(
+ SourceLoc(),
+ Diagnostics::explicitOutputPathsAndMultipleTargets);
+ }
+
+
// Did the user try to specify output path(s)?
if (rawOutputPaths.Count() != 0)
{
diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp
index 8cc93fdd2..4fbec5652 100644
--- a/source/slang/parameter-binding.cpp
+++ b/source/slang/parameter-binding.cpp
@@ -1298,30 +1298,14 @@ static RefPtr<TypeLayout> processEntryPointParameter(
static void collectEntryPointParameters(
ParameterBindingContext* context,
- EntryPointRequest* entryPoint,
- ModuleDecl* translationUnitSyntax)
+ EntryPointRequest* entryPoint)
{
- // First, look for the entry point with the specified name
-
- // Make sure we've got a query-able member dictionary
- buildMemberDictionary(translationUnitSyntax);
-
- Decl* entryPointDecl;
- if( !translationUnitSyntax->memberDictionary.TryGetValue(entryPoint->name, entryPointDecl) )
- {
- // No such entry point!
- return;
- }
- if( entryPointDecl->nextInContainerWithSameName )
- {
- // Not the only decl of that name!
- return;
- }
-
- FuncDecl* entryPointFuncDecl = dynamic_cast<FuncDecl*>(entryPointDecl);
- if( !entryPointFuncDecl )
+ FuncDecl* entryPointFuncDecl = entryPoint->decl;
+ if (!entryPointFuncDecl)
{
- // Not a function!
+ // Something must have failed earlier, so that
+ // we didn't find a declaration to match this
+ // entry point request.
return;
}
@@ -1504,7 +1488,7 @@ static void collectParameters(
for( auto& entryPoint : translationUnit->entryPoints )
{
context->stage = entryPoint->profile.GetStage();
- collectEntryPointParameters(context, entryPoint.Ptr(), translationUnit->SyntaxNode.Ptr());
+ collectEntryPointParameters(context, entryPoint.Ptr());
}
}
@@ -1515,9 +1499,14 @@ static void collectParameters(
}
}
-static bool isGLSLCrossCompilerNeeded(CompileRequest* request)
+static bool isGLSLCrossCompilerNeeded(
+ TargetRequest* targetReq)
{
- switch (request->Target)
+ auto compileReq = targetReq->compileRequest;
+
+ // We only need cross-compilation if we
+ // are targetting something GLSL-based.
+ switch (targetReq->target)
{
default:
return false;
@@ -1528,23 +1517,33 @@ static bool isGLSLCrossCompilerNeeded(CompileRequest* request)
break;
}
- if (request->loadedModulesList.Count() != 0)
+ // If we `import`ed any Slang code, then the
+ // cross compiler is definitely needed, to
+ // translate that Slang over to GLSL.
+ if (compileReq->loadedModulesList.Count() != 0)
return true;
- for (auto tu : request->translationUnits)
+ // If there are any non-GLSL translation units,
+ // then we need to cross compile those...
+ for (auto tu : compileReq->translationUnits)
{
if (tu->sourceLanguage != SourceLanguage::GLSL)
return true;
}
+ // If we get to this point, then we have plain vanilla
+ // GLSL input, with no `import` declarations, so we
+ // are able to output GLSL without cross compilation.
return false;
}
void generateParameterBindings(
- CompileRequest* request)
+ TargetRequest* targetReq)
{
+ CompileRequest* compileReq = targetReq->compileRequest;
+
// Try to find rules based on the selected code-generation target
- auto rules = GetLayoutRulesFamilyImpl(request->Target);
+ auto rules = GetLayoutRulesFamilyImpl(targetReq->target);
// If there was no target, or there are no rules for the target,
// then bail out here.
@@ -1556,7 +1555,7 @@ void generateParameterBindings(
// Create a context to hold shared state during the process
// of generating parameter bindings
SharedParameterBindingContext sharedContext;
- sharedContext.compileRequest = request;
+ sharedContext.compileRequest = compileReq;
sharedContext.defaultLayoutRules = rules;
sharedContext.programLayout = programLayout;
@@ -1568,7 +1567,7 @@ void generateParameterBindings(
context.layoutRules = sharedContext.defaultLayoutRules;
// Walk through AST to discover all the parameters
- collectParameters(&context, request);
+ collectParameters(&context, compileReq);
// Now walk through the parameters to generate initial binding information
for( auto& parameter : sharedContext.parameters )
@@ -1692,7 +1691,7 @@ void generateParameterBindings(
//
// We only want to do this if the GLSL cross-compilation support is
// being invoked, so that we don't gum up other shaders.
- if(isGLSLCrossCompilerNeeded(request))
+ if(isGLSLCrossCompilerNeeded(targetReq))
{
UInt space = 0;
auto hackSamplerUsedRanges = findUsedRangeSetForSpace(&context, space);
@@ -1702,8 +1701,8 @@ void generateParameterBindings(
programLayout->bindingForHackSampler = (int)binding;
RefPtr<Variable> var = new Variable();
- var->nameAndLoc.name = request->getNamePool()->getName("SLANG_hack_samplerForTexelFetch");
- var->type.type = getSamplerStateType(request->mSession);
+ var->nameAndLoc.name = compileReq->getNamePool()->getName("SLANG_hack_samplerForTexelFetch");
+ var->type.type = getSamplerStateType(compileReq->mSession);
auto typeLayout = new TypeLayout();
typeLayout->type = var->type.type;
@@ -1724,7 +1723,7 @@ void generateParameterBindings(
// We now have a bunch of layout information, which we should
// record into a suitable object that represents the program
programLayout->globalScopeLayout = globalScopeLayout;
- request->layout = programLayout;
+ targetReq->layout = programLayout;
}
}
diff --git a/source/slang/parameter-binding.h b/source/slang/parameter-binding.h
index 264163974..eb093821f 100644
--- a/source/slang/parameter-binding.h
+++ b/source/slang/parameter-binding.h
@@ -8,7 +8,7 @@
namespace Slang {
-class CompileRequest;
+class TargetRequest;
// The parameter-binding interface is responsible for assigning
// binding locations/registers to every parameter of a shader
@@ -24,7 +24,7 @@ class CompileRequest;
// of the program.
void generateParameterBindings(
- CompileRequest* compileRequest);
+ TargetRequest* targetReq);
}
diff --git a/source/slang/profile-defs.h b/source/slang/profile-defs.h
index 513ba3078..84153ee46 100644
--- a/source/slang/profile-defs.h
+++ b/source/slang/profile-defs.h
@@ -47,8 +47,6 @@ LANGUAGE(GLSL_ES, glsl_es)
LANGUAGE(GLSL_VK, glsl_vk)
LANGUAGE(SPIRV, spirv)
LANGUAGE(SPIRV_GL, spirv_gl)
-LANGUAGE(SlangIR, slang_ir)
-LANGUAGE(SlangIRAssembly, slang_ir_assembly)
LANGUAGE_ALIAS(GLSL, glsl_gl)
LANGUAGE_ALIAS(SPIRV, spirv_vk)
diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp
index 3ef3effab..435b4db3f 100644
--- a/source/slang/reflection.cpp
+++ b/source/slang/reflection.cpp
@@ -798,23 +798,22 @@ SLANG_API int spReflectionEntryPoint_usesAnySampleRateInput(
// Shader Reflection
+namespace Slang
+{
+ StructTypeLayout* getGlobalStructLayout(
+ ProgramLayout* programLayout);
+}
+
SLANG_API unsigned spReflection_GetParameterCount(SlangReflection* inProgram)
{
auto program = convert(inProgram);
if(!program) return 0;
- auto globalLayout = program->globalScopeLayout;
- if(auto globalConstantBufferLayout = globalLayout.As<ParameterBlockTypeLayout>())
- {
- globalLayout = globalConstantBufferLayout->elementTypeLayout;
- }
-
- if(auto globalStructLayout = globalLayout.As<StructTypeLayout>())
- {
- return (unsigned) globalStructLayout->fields.Count();
- }
+ auto globalStructLayout = getGlobalStructLayout(program);
+ if (!globalStructLayout)
+ return 0;
- return 0;
+ return (unsigned) globalStructLayout->fields.Count();
}
SLANG_API SlangReflectionParameter* spReflection_GetParameterByIndex(SlangReflection* inProgram, unsigned index)
@@ -822,18 +821,11 @@ SLANG_API SlangReflectionParameter* spReflection_GetParameterByIndex(SlangReflec
auto program = convert(inProgram);
if(!program) return nullptr;
- auto globalLayout = program->globalScopeLayout;
- if(auto globalConstantBufferLayout = globalLayout.As<ParameterBlockTypeLayout>())
- {
- globalLayout = globalConstantBufferLayout->elementTypeLayout;
- }
-
- if(auto globalStructLayout = globalLayout.As<StructTypeLayout>())
- {
- return convert(globalStructLayout->fields[index].Ptr());
- }
+ auto globalStructLayout = getGlobalStructLayout(program);
+ if (!globalStructLayout)
+ return 0;
- return nullptr;
+ return convert(globalStructLayout->fields[index].Ptr());
}
SLANG_API SlangUInt spReflection_getEntryPointCount(SlangReflection* inProgram)
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 6e57c104c..dff322c29 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -2,6 +2,7 @@
#include "../core/slang-io.h"
#include "parameter-binding.h"
+#include "lower-to-ir.h"
#include "../slang/parser.h"
#include "../slang/preprocessor.h"
#include "../slang/reflection.h"
@@ -175,11 +176,46 @@ void CompileRequest::parseTranslationUnit(
void CompileRequest::checkAllTranslationUnits()
{
+ // Iterate over all translation units and
+ // apply the semantic checking logic.
for( auto& translationUnit : translationUnits )
{
checkTranslationUnit(translationUnit.Ptr());
}
}
+
+void CompileRequest::generateIR()
+{
+ // Our task in this function is to generate IR code
+ // for all of the declarations in the translation
+ // units that were loaded.
+
+ // At the moment, use of the IR is not enabled by
+ // default, so we will skip this step unless
+ // the flag was set to op in.
+ if (!(compileFlags & SLANG_COMPILE_FLAG_USE_IR))
+ return;
+
+ // Each translation unit is its own little world
+ // for code generation (we are not trying to
+ // replicate the GLSL linkage model), and so
+ // we will generate IR for each (if needed)
+ // in isolation.
+ for( auto& translationUnit : translationUnits )
+ {
+ // If the user opted out of semantic checking for
+ // the translation unit, then IR code generation
+ // is not in general even possible; there might
+ // be semantics errors (diagnosed or not) in the
+ // code, and we don't want to deal with those.
+ if (translationUnit->compileFlags & SLANG_COMPILE_FLAG_NO_CHECKING)
+ continue;
+
+ // Okay, we seem to be in the clear now.
+ translationUnit->irModule = generateIRForTranslationUnit(translationUnit);
+ }
+}
+
// Try to infer a single common source language for a request
static SourceLanguage inferSourceLanguage(CompileRequest* request)
{
@@ -223,18 +259,18 @@ int CompileRequest::executeActionsInner()
}
// If no code-generation target was specified, then try to infer one from the source language,
- // just to make sure we can do something reasonable when `reflection-json` is specified
- if (Target == CodeGenTarget::Unknown)
+ // just to make sure we can do something reasonable when invoked from the command line.
+ if (targets.Count() == 0)
{
auto language = inferSourceLanguage(this);
switch (language)
{
case SourceLanguage::HLSL:
- Target = CodeGenTarget::DXBytecodeAssembly;
+ addTarget(CodeGenTarget::DXBytecode);
break;
case SourceLanguage::GLSL:
- Target = CodeGenTarget::SPIRVAssembly;
+ addTarget(CodeGenTarget::SPIRV);
break;
default:
@@ -242,40 +278,6 @@ int CompileRequest::executeActionsInner()
}
}
-#if 0
- // If we are being asked to do pass-through, then we need to do that here...
- if (passThrough != PassThroughMode::None)
- {
- for (auto& translationUnitOptions : Options.translationUnits)
- {
- switch (translationUnitOptions.sourceLanguage)
- {
- // We can pass-through code written in a native shading language
- case SourceLanguage::GLSL:
- case SourceLanguage::HLSL:
- break;
-
- // All other translation units need to be skipped
- default:
- continue;
- }
-
- auto sourceFile = translationUnitOptions.sourceFiles[0];
- auto sourceFilePath = sourceFile->path;
- String source = sourceFile->content;
-
- auto translationUnitResult = passThrough(
- source,
- sourceFilePath,
- Options,
- translationUnitOptions);
-
- mResult.translationUnits.Add(translationUnitResult);
- }
- return 0;
- }
-#endif
-
// We only do parsing and semantic checking if we *aren't* doing
// a pass-through compilation.
//
@@ -295,17 +297,29 @@ int CompileRequest::executeActionsInner()
if (mSink.GetErrorCount() != 0)
return 1;
- // Now do shader parameter binding generation, which
- // needs to be performed globally.
- generateParameterBindings(this);
+ // Generate initial IR for all the translation
+ // units, if we are in a mode where IR is called for.
+ generateIR();
if (mSink.GetErrorCount() != 0)
return 1;
+
+ // For each code generation target generate
+ // parameter binding information.
+ // This step is done globaly, because all translation
+ // units and entry points need to agree on where
+ // parameters are allocated.
+ for (auto targetReq : targets)
+ {
+ generateParameterBindings(targetReq);
+ if (mSink.GetErrorCount() != 0)
+ return 1;
+ }
}
// If command line specifies to skip codegen, we exit here.
// Note: this is a debugging option.
-// if (shouldSkipCodegen)
-// return 0;
+ if (shouldSkipCodegen)
+ return 0;
// Generate output code, in whatever format was requested
generateOutput(this);
@@ -401,6 +415,19 @@ int CompileRequest::addEntryPoint(
return (int) result;
}
+UInt CompileRequest::addTarget(
+ CodeGenTarget target)
+{
+ RefPtr<TargetRequest> targetReq = new TargetRequest();
+ targetReq->compileRequest = this;
+ targetReq->target = target;
+
+ UInt result = targets.Count();
+ targets.Add(targetReq);
+ return (int) result;
+}
+
+
RefPtr<ModuleDecl> CompileRequest::loadModule(
Name* name,
String const& path,
@@ -713,9 +740,28 @@ SLANG_API void spSetCodeGenTarget(
SlangCompileRequest* request,
int target)
{
- REQ(request)->Target = (Slang::CodeGenTarget)target;
+ auto req = REQ(request);
+ req->targets.Clear();
+ req->addTarget(Slang::CodeGenTarget(target));
}
+SLANG_API void spAddCodeGenTarget(
+ SlangCompileRequest* request,
+ SlangCompileTarget target)
+{
+ auto req = REQ(request);
+ req->addTarget(Slang::CodeGenTarget(target));
+}
+
+SLANG_API void spSetOutputContainerFormat(
+ SlangCompileRequest* request,
+ SlangContainerFormat format)
+{
+ auto req = REQ(request);
+ req->containerFormat = Slang::ContainerFormat(format);
+}
+
+
SLANG_API void spSetPassThrough(
SlangCompileRequest* request,
SlangPassThrough passThrough)
@@ -918,16 +964,8 @@ SLANG_API char const* spGetTranslationUnitSource(
SlangCompileRequest* request,
int translationUnitIndex)
{
- auto req = REQ(request);
- return req->translationUnits[translationUnitIndex]->result.outputString.Buffer();
-}
-
-SLANG_API char const* spGetEntryPointSource(
- SlangCompileRequest* request,
- int entryPointIndex)
-{
- auto req = REQ(request);
- return req->entryPoints[entryPointIndex]->result.outputString.Buffer();
+ fprintf(stderr, "DEPRECATED: spGetTranslationUnitSource()\n");
+ return nullptr;
}
SLANG_API void const* spGetEntryPointCode(
@@ -936,7 +974,14 @@ SLANG_API void const* spGetEntryPointCode(
size_t* outSize)
{
auto req = REQ(request);
- Slang::CompileResult& result = req->entryPoints[entryPointIndex]->result;
+
+ // TODO: We should really accept a target index in this API
+ auto targetCount = req->targets.Count();
+ if (targetCount == 0)
+ return nullptr;
+ auto targetReq = req->targets[0];
+
+ Slang::CompileResult& result = targetReq->entryPointResults[entryPointIndex];
void const* data = nullptr;
size_t size = 0;
@@ -962,15 +1007,48 @@ SLANG_API void const* spGetEntryPointCode(
return data;
}
+SLANG_API char const* spGetEntryPointSource(
+ SlangCompileRequest* request,
+ int entryPointIndex)
+{
+ return (char const*) spGetEntryPointCode(request, entryPointIndex, nullptr);
+}
+
+SLANG_API void const* spGetCompileRequestCode(
+ SlangCompileRequest* request,
+ size_t* outSize)
+{
+ auto req = REQ(request);
+
+ void const* data = req->generatedBytecode.Buffer();
+ size_t size = req->generatedBytecode.Count();
+
+ if(outSize) *outSize = size;
+ return data;
+}
+
// Reflection API
SLANG_API SlangReflection* spGetReflection(
SlangCompileRequest* request)
{
if( !request ) return 0;
-
auto req = REQ(request);
- return (SlangReflection*) req->layout.Ptr();
+
+ // Note(tfoley): The API signature doesn't let the client
+ // specify which target they want to access reflection
+ // information for, so for now we default to the first one.
+ //
+ // TODO: Add a new `spGetReflectionForTarget(req, targetIndex)`
+ // so that we can do this better, and make it clear that
+ // `spGetReflection()` is shorthand for `targetIndex == 0`.
+ //
+ auto targetCount = req->targets.Count();
+ if (targetCount == 0)
+ return 0;
+ auto targetReq = req->targets[0];
+
+ return (SlangReflection*) targetReq->layout.Ptr();
}
// ... rest of reflection API implementation is in `Reflection.cpp`
diff --git a/source/slang/syntax-visitors.h b/source/slang/syntax-visitors.h
index 5c32c4a36..9644deae1 100644
--- a/source/slang/syntax-visitors.h
+++ b/source/slang/syntax-visitors.h
@@ -7,6 +7,7 @@
namespace Slang
{
class CompileRequest;
+ class EntryPointRequest;
class ShaderCompiler;
class ShaderLinkInfo;
class ShaderSymbol;
diff --git a/source/slang/type-layout.cpp b/source/slang/type-layout.cpp
index 034982ae4..19b9a435e 100644
--- a/source/slang/type-layout.cpp
+++ b/source/slang/type-layout.cpp
@@ -621,7 +621,6 @@ LayoutRulesFamilyImpl* GetLayoutRulesFamilyImpl(CodeGenTarget target)
case CodeGenTarget::HLSL:
case CodeGenTarget::DXBytecode:
case CodeGenTarget::DXBytecodeAssembly:
- case CodeGenTarget::SlangIR:
return &kHLSLLayoutRulesFamilyImpl;
case CodeGenTarget::GLSL:
diff --git a/source/slang/vm.cpp b/source/slang/vm.cpp
index d3e44a947..f129d15e0 100644
--- a/source/slang/vm.cpp
+++ b/source/slang/vm.cpp
@@ -670,7 +670,11 @@ VMModule* loadVMModuleInstance(
{
BCHeader* bcHeader = (BCHeader*) bytecode;
- BCModule* bcModule = bcHeader->module;
+ UInt bcModuleCount = bcHeader->moduleCount;
+ if (bcModuleCount == 0)
+ return nullptr;
+
+ BCModule* bcModule = bcHeader->modules[0];
UInt symbolCount = bcModule->symbolCount;
UInt typeCount = bcModule->typeCount;
diff --git a/source/slangc/main.cpp b/source/slangc/main.cpp
index 7c57fa8a9..391642c3e 100644
--- a/source/slangc/main.cpp
+++ b/source/slangc/main.cpp
@@ -66,25 +66,6 @@ int MAIN(int argc, char** argv)
exit(-1);
}
-#if 0
- // Produce output as the command-line compiler driver should.
-
- // Now dump the output from the compilation to stdout.
- //
- // TODO: Need a way to control where output goes so that
- // we can actually use the standalone compiler as something
- // more than a testing tool.
- //
-
- int translationUnitCount = spGetTranslationUnitCount(compileRequest);
- for(int tt = 0; tt < translationUnitCount; ++tt)
- {
- auto output = spGetTranslationUnitSource(compileRequest, tt);
- fputs(output, stdout);
- }
- fflush(stdout);
-#endif
-
// Now that we are done, clean up after ourselves
spDestroyCompileRequest(compileRequest);
diff --git a/tests/ir/loop.slang b/tests/ir/loop.slang
index acdba0b20..0342d914a 100644
--- a/tests/ir/loop.slang
+++ b/tests/ir/loop.slang
@@ -1,4 +1,4 @@
-//TEST:SIMPLE:-use-ir -dump-ir -skip-codegen -target hlsl -profile cs_5_0 -entry main
+//TEST:SIMPLE:-use-ir -dump-ir -profile cs_5_0 -entry main
#define GROUP_THREAD_COUNT 64
diff --git a/tools/eval-test/main.cpp b/tools/eval-test/main.cpp
index 9fb6f94a3..486de7bd9 100644
--- a/tools/eval-test/main.cpp
+++ b/tools/eval-test/main.cpp
@@ -37,9 +37,13 @@ int main(
SlangSession* session = spCreateSession(nullptr);
SlangCompileRequest* request = spCreateCompileRequest(session);
- spSetCodeGenTarget(
+ spSetCompileFlags(
request,
- SLANG_IR);
+ SLANG_COMPILE_FLAG_USE_IR);
+
+ spSetOutputContainerFormat(
+ request,
+ SLANG_CONTAINER_FORMAT_SLANG_MODULE);
int translationUnitIndex = spAddTranslationUnit(
request,
@@ -69,7 +73,7 @@ int main(
// Extract the bytecode
size_t bytecodeSize = 0;
- void const* bytecode = spGetEntryPointCode(request, entryPointIndex, &bytecodeSize);
+ void const* bytecode = spGetCompileRequestCode(request, &bytecodeSize);
// Now we need to create an execution context to go and run the bytecode we got
diff --git a/tools/render-test/render-d3d11.cpp b/tools/render-test/render-d3d11.cpp
index dc8f11438..bc86809e6 100644
--- a/tools/render-test/render-d3d11.cpp
+++ b/tools/render-test/render-d3d11.cpp
@@ -68,261 +68,6 @@ static char const* fragmentProfileName = "ps_4_0";
ID3DBlob* gVertexShaderBlob;
ID3DBlob* gPixelShaderBlob;
-// Initialization when using HLSL for shaders
-HRESULT initializeHLSLInner(ID3D11Device* dxDevice, char const* sourcePath, char const* sourceText)
-{
- // Compile the generated HLSL code
- gVertexShaderBlob = compileHLSLShader(sourcePath, sourceText, vertexEntryPointName, vertexProfileName);
- if(!gVertexShaderBlob) return E_FAIL;
-
- gPixelShaderBlob = compileHLSLShader(sourcePath, sourceText, fragmentEntryPointName, fragmentProfileName);
- if(!gPixelShaderBlob) return E_FAIL;
-
-
- return S_OK;
-}
-
-// Initialization when using HLSL for shaders
-HRESULT initializeHLSL(ID3D11Device* dxDevice, char const* sourceText)
-{
- HRESULT hr = initializeHLSLInner(dxDevice, gOptions.sourcePath, sourceText);
- if(FAILED(hr))
- return hr;
-
- // TODO: any reflection stuff to do here?
-
- return S_OK;
-}
-
-// Initialization when using Slang for shaders
-HRESULT initializeSlang(ID3D11Device* dxDevice, char const* sourceText)
-{
- //
- // First, we will load and compile our Slang source code.
- //
-
- // The argument here is an optional directory where the Slang compiler
- // can cache files to speed up compilation of many kernels.
- SlangSession* slangSession = spCreateSession(NULL);
-
- // A compile request represents a single invocation of the compiler,
- // to process some inputs and produce outputs (or errors).
- SlangCompileRequest* slangRequest = spCreateCompileRequest(slangSession);
-
- // Instruct Slang to generate code as HLSL
- spSetCodeGenTarget(slangRequest, SLANG_HLSL);
-
- int translationUnitIndex = spAddTranslationUnit(slangRequest, SLANG_SOURCE_LANGUAGE_SLANG, nullptr);
-
- spAddTranslationUnitSourceString(slangRequest, translationUnitIndex, gOptions.sourcePath, sourceText);
-
- spAddEntryPoint(slangRequest, translationUnitIndex, vertexEntryPointName, spFindProfile(slangSession, vertexProfileName));
- spAddEntryPoint(slangRequest, translationUnitIndex, fragmentEntryPointName, spFindProfile(slangSession, fragmentProfileName));
-
- int compileErr = spCompile(slangRequest);
- if(auto diagnostics = spGetDiagnosticOutput(slangRequest))
- {
- OutputDebugStringA(diagnostics);
- fprintf(stderr, "%s", diagnostics);
- }
- if(compileErr)
- {
- return E_FAIL;
- }
-
- char const* translatedCode = spGetTranslationUnitSource(slangRequest, translationUnitIndex);
-
- // Compile the generated HLSL code
- HRESULT hr = initializeHLSLInner(dxDevice, "slangGeneratedCode", translatedCode);
- if(FAILED(hr))
- return hr;
-
- // We clean up the Slang compilation context and result *after*
- // we have done the HLSL-to-bytecode compilation, because Slang
- // owns the memory allocation for the generated HLSL, and will
- // free it when we destroy the compilation result.
- spDestroyCompileRequest(slangRequest);
- spDestroySession(slangSession);
-
- return S_OK;
-}
-
-#if 0
-
-//
-// At initialization time, we are going to load and compile our Slang shader
-// code, and then create the D3D11 API objects we need for rendering.
-//
-HRESULT initializeInner( ID3D11Device* dxDevice )
-{
- HRESULT hr = S_OK;
-
- // Read in the source code
- char const* sourcePath = gOptions.sourcePath;
- FILE* sourceFile = fopen(sourcePath, "rb");
- if( !sourceFile )
- {
- fprintf(stderr, "error: failed to open '%s' for reading\n", sourcePath);
- exit(1);
- }
- fseek(sourceFile, 0, SEEK_END);
- size_t sourceSize = ftell(sourceFile);
- fseek(sourceFile, 0, SEEK_SET);
- char* sourceText = (char*) malloc(sourceSize + 1);
- if( !sourceText )
- {
- fprintf(stderr, "error: out of memory");
- exit(1);
- }
- fread(sourceText, sourceSize, 1, sourceFile);
- fclose(sourceFile);
- sourceText[sourceSize] = 0;
-
- switch( gOptions.mode )
- {
- case Mode::HLSL:
- hr = initializeHLSL(dxDevice, sourceText);
- break;
-
- case Mode::Slang:
- hr = initializeSlang(dxDevice, sourceText);
- break;
-
- default:
- hr = E_FAIL;
- break;
- }
- if( FAILED(hr) )
- {
- return hr;
- }
-
- // Do other initialization that doesn't depend on the source language.
-
- // TODO(tfoley): use each API's reflection interface to query the constant-buffer size needed
- gConstantBufferSize = 16 * sizeof(float);
-
-
- D3D11_BUFFER_DESC dxConstantBufferDesc = { 0 };
- dxConstantBufferDesc.ByteWidth = gConstantBufferSize;
- dxConstantBufferDesc.Usage = D3D11_USAGE_DYNAMIC;
- dxConstantBufferDesc.BindFlags = D3D11_BIND_CONSTANT_BUFFER;
- dxConstantBufferDesc.CPUAccessFlags = D3D11_CPU_ACCESS_WRITE;
-
- hr = dxDevice->CreateBuffer(
- &dxConstantBufferDesc,
- NULL,
- &dxConstantBuffer);
- if(FAILED(hr)) return hr;
-
-
- // Input Assembler (IA)
-
- // In Slang-generated HLSL, all vertex shader inputs have a semantic
- // like: `A0`, `A1`, `A2`, etc., rather than trying to do by-name
- // matching. The user is thus responsibile for ensuring that the
- // order of their "input element descs" here matches the order
- // in which inputs are declared in the shader code.
- D3D11_INPUT_ELEMENT_DESC dxInputElements[] = {
- {"A", 0, DXGI_FORMAT_R32G32B32_FLOAT, 0, offsetof(Vertex, position), D3D11_INPUT_PER_VERTEX_DATA, 0 },
- {"A", 1, DXGI_FORMAT_R32G32B32_FLOAT, 0, offsetof(Vertex, color), D3D11_INPUT_PER_VERTEX_DATA, 0 },
- };
- hr = dxDevice->CreateInputLayout(
- &dxInputElements[0],
- 2,
- gVertexShaderBlob->GetBufferPointer(),
- gVertexShaderBlob->GetBufferSize(),
- &dxInputLayout);
- if(FAILED(hr)) return hr;
-
- D3D11_BUFFER_DESC dxVertexBufferDesc = { 0 };
- dxVertexBufferDesc.ByteWidth = kVertexCount * sizeof(Vertex);
- dxVertexBufferDesc.Usage = D3D11_USAGE_IMMUTABLE;
- dxVertexBufferDesc.BindFlags = D3D11_BIND_VERTEX_BUFFER;
-
- D3D11_SUBRESOURCE_DATA dxVertexBufferInitData = { 0 };
- dxVertexBufferInitData.pSysMem = &kVertexData[0];
-
- hr = dxDevice->CreateBuffer(
- &dxVertexBufferDesc,
- &dxVertexBufferInitData,
- &dxVertexBuffer);
- if(FAILED(hr)) return hr;
-
- // Vertex Shader (VS)
-
- hr = dxDevice->CreateVertexShader(
- gVertexShaderBlob->GetBufferPointer(),
- gVertexShaderBlob->GetBufferSize(),
- NULL,
- &dxVertexShader);
- gVertexShaderBlob->Release();
- if(FAILED(hr)) return hr;
-
- // Pixel Shader (PS)
-
- hr = dxDevice->CreatePixelShader(
- gPixelShaderBlob->GetBufferPointer(),
- gPixelShaderBlob->GetBufferSize(),
- NULL,
- &dxPixelShader);
- gPixelShaderBlob->Release();
- if(FAILED(hr)) return hr;
-
- return S_OK;
-}
-
-void renderFrameInner(ID3D11DeviceContext* dxContext)
-{
- // We update our constant buffer per-frame, just for the purposes
- // of the example, but we don't actually load different data
- // per-frame (we always use an identity projection).
- D3D11_MAPPED_SUBRESOURCE mapped;
- HRESULT hr = dxContext->Map(dxConstantBuffer, 0, D3D11_MAP_WRITE_DISCARD, 0, &mapped);
- if(!FAILED(hr))
- {
- float* data = (float*) mapped.pData;
-
- static const float kIdentity[] =
- { 1, 0, 0, 0,
- 0, 1, 0, 0,
- 0, 0, 1, 0,
- 0, 0, 0, 1 };
- memcpy(data, kIdentity, sizeof(kIdentity));
-
- dxContext->Unmap(dxConstantBuffer, 0);
- }
-
- // Input Assembler (IA)
-
- dxContext->IASetInputLayout(dxInputLayout);
- dxContext->IASetPrimitiveTopology(D3D11_PRIMITIVE_TOPOLOGY_TRIANGLELIST);
-
- UINT dxVertexStride = sizeof(Vertex);
- UINT dxVertexBufferOffset = 0;
- dxContext->IASetVertexBuffers(0, 1, &dxVertexBuffer, &dxVertexStride, &dxVertexBufferOffset);
-
- // Vertex Shader (VS)
-
- dxContext->VSSetShader(dxVertexShader, NULL, 0);
- dxContext->VSSetConstantBuffers(0, 1, &dxConstantBuffer);
-
- // Pixel Shader (PS)
-
- dxContext->PSSetShader(dxPixelShader, NULL, 0);
- dxContext->VSSetConstantBuffers(0, 1, &dxConstantBuffer);
-
- //
-
- dxContext->Draw(3, 0);
-}
-
-void finalize()
-{
-}
-
-#endif
-
//
// Definition of the HLSL-to-bytecode compilation logic.
//
diff --git a/tools/render-test/slang-support.cpp b/tools/render-test/slang-support.cpp
index dbddd9c4f..bf8b7b9c7 100644
--- a/tools/render-test/slang-support.cpp
+++ b/tools/render-test/slang-support.cpp
@@ -38,6 +38,8 @@ struct SlangShaderCompilerWrapper : public ShaderCompiler
int vertexTranslationUnit = 0;
int fragmentTranslationUnit = 0;
+ char const* vertexEntryPointName = request.vertexShader.name;
+ char const* fragmentEntryPointName = request.fragmentShader.name;
if( sourceLanguage == SLANG_SOURCE_LANGUAGE_GLSL )
{
// GLSL presents unique challenges because, frankly, it got the whole
@@ -48,13 +50,13 @@ struct SlangShaderCompilerWrapper : public ShaderCompiler
vertexTranslationUnit = spAddTranslationUnit(slangRequest, sourceLanguage, nullptr);
spAddTranslationUnitSourceString(slangRequest, vertexTranslationUnit, request.source.path, request.source.text);
-
spTranslationUnit_addPreprocessorDefine(slangRequest, vertexTranslationUnit, "__GLSL_VERTEX__", "1");
+ vertexEntryPointName = "main";
fragmentTranslationUnit = spAddTranslationUnit(slangRequest, sourceLanguage, nullptr);
spAddTranslationUnitSourceString(slangRequest, fragmentTranslationUnit, request.source.path, request.source.text);
-
spTranslationUnit_addPreprocessorDefine(slangRequest, fragmentTranslationUnit, "__GLSL_FRAGMENT__", "1");
+ fragmentEntryPointName = "main";
}
else
{
@@ -72,8 +74,8 @@ struct SlangShaderCompilerWrapper : public ShaderCompiler
spSetCompileFlags(slangRequest, SLANG_COMPILE_FLAG_NO_CHECKING);
}
- int vertexEntryPoint = spAddEntryPoint(slangRequest, vertexTranslationUnit, request.vertexShader.name, spFindProfile(slangSession, request.vertexShader.profile));
- int fragmentEntryPoint = spAddEntryPoint(slangRequest, fragmentTranslationUnit, request.fragmentShader.name, spFindProfile(slangSession, request.fragmentShader.profile));
+ int vertexEntryPoint = spAddEntryPoint(slangRequest, vertexTranslationUnit, vertexEntryPointName, spFindProfile(slangSession, request.vertexShader.profile));
+ int fragmentEntryPoint = spAddEntryPoint(slangRequest, fragmentTranslationUnit, fragmentEntryPointName, spFindProfile(slangSession, request.fragmentShader.profile));
int compileErr = spCompile(slangRequest);
if(auto diagnostics = spGetDiagnosticOutput(slangRequest))
@@ -90,12 +92,6 @@ struct SlangShaderCompilerWrapper : public ShaderCompiler
ShaderCompileRequest innerRequest = request;
- if( sourceLanguage != SLANG_SOURCE_LANGUAGE_GLSL )
- {
- char const* translatedCode = spGetTranslationUnitSource(slangRequest, 0);
- innerRequest.source.text = translatedCode;
- }
-
char const* vertexCode = spGetEntryPointSource(slangRequest, vertexEntryPoint);
char const* fragmentCode = spGetEntryPointSource(slangRequest, fragmentEntryPoint);