summaryrefslogtreecommitdiffstats
path: root/tools/gfx/cuda/cuda-helper-functions.h
blob: 2b2614bb05c63779262ba1ae9e9c911e7779a6e5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
// cuda-helper-functions.h
#pragma once

#include "../../../source/core/slang-list.h"
#include "cuda-base.h"
#include "slang-gfx.h"

namespace gfx
{
using namespace Slang;

#ifdef GFX_ENABLE_CUDA
namespace cuda
{
// Returns true if `result` is anything other than CUDA_SUCCESS.
SLANG_FORCE_INLINE bool _isError(CUresult result)
{
    return result != CUDA_SUCCESS;
}

// An enum used to control whether errors are reported on failure of a CUDA call.
// Consumed by SLANG_CUDA_RETURN_WITH_REPORT_ON_FAIL below.
enum class CUDAReportStyle
{
    Normal, // Route the failure through SLANG_CUDA_HANDLE_ERROR.
    Silent, // Return SLANG_FAIL without calling the error handler.
};

// Call-site location plus optional CUDA error name/description for a failed call.
// The string pointers are stored as-is (not copied or freed by this struct), so
// they must outlive it.
struct CUDAErrorInfo
{
    CUDAErrorInfo(
        const char* filePath,
        int lineNo,
        const char* errorName = nullptr,
        const char* errorString = nullptr)
        : m_filePath(filePath), m_lineNo(lineNo), m_errorName(errorName), m_errorString(errorString)
    {
    }

    // Handles the recorded error and returns the SlangResult to propagate.
    // Defined out of line. NOTE(review): exact reporting behavior lives in the implementation.
    SlangResult handle() const;

    const char* m_filePath;    // Source file of the failing call (presumably __FILE__).
    int m_lineNo;              // Source line of the failing call (presumably __LINE__).
    const char* m_errorName;   // Optional; nullptr when not supplied.
    const char* m_errorString; // Optional; nullptr when not supplied.
};

// Handles a failed CUDA driver call made at `file`:`line` and returns the
// SlangResult the caller should propagate. Defined out of line.
// NOTE(review): an earlier comment stated that CUDA errors are reported directly to
// the StdWriter::out stream - confirm against the implementation.
SlangResult _handleCUDAError(CUresult cuResult, const char* file, int line);

// Forwards a CUresult to _handleCUDAError together with the current source location.
#define SLANG_CUDA_HANDLE_ERROR(x) _handleCUDAError((x), __FILE__, __LINE__)

// The *_ON_FAIL macros below evaluate their expression exactly once. They expand to a
// single `do { ... } while (0)` statement, so they are safe inside an unbraced if/else
// and must be followed by a ';' (matching SLANG_OPTIX_RETURN_ON_FAIL). The hidden local
// has a macro-specific name so it cannot shadow a caller variable used inside `x`.

// If `x` yields a CUDA error, return the handled SlangResult from the enclosing function.
#define SLANG_CUDA_RETURN_ON_FAIL(x)                       \
    do                                                     \
    {                                                      \
        auto _slangCudaRes = (x);                          \
        if (_isError(_slangCudaRes))                       \
            return SLANG_CUDA_HANDLE_ERROR(_slangCudaRes); \
    } while (0)

// If `x` yields a CUDA error, return from the enclosing function. When `r` is
// CUDAReportStyle::Normal the error is routed through SLANG_CUDA_HANDLE_ERROR;
// otherwise SLANG_FAIL is returned without calling the handler.
#define SLANG_CUDA_RETURN_WITH_REPORT_ON_FAIL(x, r)                                 \
    do                                                                              \
    {                                                                               \
        auto _slangCudaRes = (x);                                                   \
        if (_isError(_slangCudaRes))                                                \
        {                                                                           \
            return ((r) == CUDAReportStyle::Normal) ? SLANG_CUDA_HANDLE_ERROR(_slangCudaRes) \
                                                    : SLANG_FAIL;                   \
        }                                                                           \
    } while (0)

// Evaluates `x` and triggers SLANG_ASSERT if it yields a CUDA error. `x` is evaluated
// even when SLANG_ASSERT is compiled out, so the side effects of the call are kept.
#define SLANG_CUDA_ASSERT_ON_FAIL(x)               \
    do                                             \
    {                                              \
        auto _slangCudaRes = (x);                  \
        if (_isError(_slangCudaRes))               \
        {                                          \
            SLANG_ASSERT(!"Failed CUDA call");     \
        }                                          \
    } while (0)

#ifdef RENDER_TEST_OPTIX

// OptiX counterpart of _isError(CUresult). Defined out of line; presumably true for
// any non-success OptixResult.
bool _isError(OptixResult result);

// Change to `#if 0` to bypass _handleOptixError and turn OptiX failures into a bare SLANG_FAIL.
#if 1
SlangResult _handleOptixError(OptixResult result, char const* file, int line);

#define SLANG_OPTIX_HANDLE_ERROR(RESULT) _handleOptixError((RESULT), __FILE__, __LINE__)
#else
#define SLANG_OPTIX_HANDLE_ERROR(RESULT) SLANG_FAIL
#endif

// If EXPR yields an OptiX error, return the handled SlangResult from the enclosing function.
#define SLANG_OPTIX_RETURN_ON_FAIL(EXPR)                    \
    do                                                      \
    {                                                       \
        auto _slangOptixRes = (EXPR);                       \
        if (_isError(_slangOptixRes))                       \
            return SLANG_OPTIX_HANDLE_ERROR(_slangOptixRes); \
    } while (0)

// Log callback for OptiX (defined out of line). NOTE(review): presumably installed as the
// device context's log callback - confirm against the call site.
void _optixLogCallback(unsigned int level, const char* tag, const char* message, void* userData);

#endif

// Returns the adapter LUID for the CUDA device at index `deviceIndex`. Defined out of line.
AdapterLUID getAdapterLUID(int deviceIndex);

// Version-aware cuCtxCreate wrapper that works with both CUDA 12 and CUDA 13.
// In CUDA 13+, cuCtxCreate takes an extra CUctxCreateParams*; this wrapper passes
// default (zero-initialized) creation parameters so both versions behave the same
// for callers that only supply flags and a device.
inline CUresult createCudaContext(CUcontext* pctx, unsigned int flags, CUdevice dev)
{
#if CUDA_VERSION >= 13000
    // CUDA 13+ requires CUctxCreateParams
    CUctxCreateParams ctxCreateParams = {};
    return cuCtxCreate(pctx, &ctxCreateParams, flags, dev);
#else
    // CUDA 12 and earlier use the old signature
    return cuCtxCreate(pctx, flags, dev);
#endif
}

} // namespace cuda
#endif

// The two entry points below are declared outside the GFX_ENABLE_CUDA guard, so they
// are always visible to callers. NOTE(review): presumably they have fallback
// implementations when CUDA support is disabled - confirm in the .cpp.

// Fills `outAdapters` with the available CUDA adapters.
Result SLANG_MCALL getCUDAAdapters(List<AdapterInfo>& outAdapters);

// Creates a CUDA-backed IDevice from `desc` and returns it through `outDevice`.
Result SLANG_MCALL createCUDADevice(const IDevice::Desc* desc, IDevice** outDevice);

} // namespace gfx