diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs index e6ef4933e..753273faa 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/CodeGenContext.cs @@ -20,6 +20,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv public int InputVertices { get; } public Dictionary UniformBuffers { get; } = new Dictionary(); + public Instruction UniformBuffersArray { get; set; } public Instruction StorageBuffersArray { get; set; } public Instruction LocalMemory { get; set; } public Instruction SharedMemory { get; set; } @@ -332,12 +333,26 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv public Instruction GetConstantBuffer(AggregateType type, AstOperand operand) { - var ubVariable = UniformBuffers[operand.CbufSlot]; - var i0 = Constant(TypeS32(), 0); - var i1 = Constant(TypeS32(), operand.CbufOffset >> 2); - var i2 = Constant(TypeU32(), operand.CbufOffset & 3); + var i1 = Constant(TypeS32(), 0); + var i2 = Constant(TypeS32(), operand.CbufOffset >> 2); + var i3 = Constant(TypeU32(), operand.CbufOffset & 3); + + Instruction elemPointer; + + if (UniformBuffersArray != null) + { + var ubVariable = UniformBuffersArray; + var i0 = Constant(TypeS32(), operand.CbufSlot); + + elemPointer = AccessChain(TypePointer(StorageClass.Uniform, TypeFP32()), ubVariable, i0, i1, i2, i3); + } + else + { + var ubVariable = UniformBuffers[operand.CbufSlot]; + + elemPointer = AccessChain(TypePointer(StorageClass.Uniform, TypeFP32()), ubVariable, i1, i2, i3); + } - var elemPointer = AccessChain(TypePointer(StorageClass.Uniform, TypeFP32()), ubVariable, i0, i1, i2); return BitcastIfNeeded(type, AggregateType.FP32, Load(TypeFP32(), elemPointer)); } diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs index f6c81a3a2..f3c16e919 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Declarations.cs @@ -125,6 +125,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv private static void DeclareUniformBuffers(CodeGenContext context, BufferDescriptor[] descriptors) { + if (descriptors.Length == 0) + { + return; + } + uint ubSize = Constants.ConstantBufferSize / 16; var ubArrayType = context.TypeArray(context.TypeVector(context.TypeFP32(), 4), context.Constant(context.TypeU32(), ubSize), true); @@ -132,17 +137,36 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv var ubStructType = context.TypeStruct(true, ubArrayType); context.Decorate(ubStructType, Decoration.Block); context.MemberDecorate(ubStructType, 0, Decoration.Offset, (LiteralInteger)0); - var ubPointerType = context.TypePointer(StorageClass.Uniform, ubStructType); - foreach (var descriptor in descriptors) + if (context.Config.UsedFeatures.HasFlag(FeatureFlags.CbIndexing)) { + int count = descriptors.Max(x => x.Slot) + 1; + + var ubStructArrayType = context.TypeArray(ubStructType, context.Constant(context.TypeU32(), count)); + var ubPointerType = context.TypePointer(StorageClass.Uniform, ubStructArrayType); var ubVariable = context.Variable(ubPointerType, StorageClass.Uniform); - context.Name(ubVariable, $"{GetStagePrefix(context.Config.Stage)}_c{descriptor.Slot}"); + context.Name(ubVariable, $"{GetStagePrefix(context.Config.Stage)}_u"); context.Decorate(ubVariable, Decoration.DescriptorSet, (LiteralInteger)0); - context.Decorate(ubVariable, Decoration.Binding, (LiteralInteger)descriptor.Binding); + context.Decorate(ubVariable, Decoration.Binding, (LiteralInteger)context.Config.FirstConstantBufferBinding); context.AddGlobalVariable(ubVariable); - context.UniformBuffers.Add(descriptor.Slot, ubVariable); + + context.UniformBuffersArray = ubVariable; + } + else + { + var ubPointerType = context.TypePointer(StorageClass.Uniform, ubStructType); + + foreach (var descriptor in descriptors) + { + var ubVariable = context.Variable(ubPointerType, StorageClass.Uniform); + + context.Name(ubVariable, $"{GetStagePrefix(context.Config.Stage)}_c{descriptor.Slot}"); + context.Decorate(ubVariable, Decoration.DescriptorSet, (LiteralInteger)0); + context.Decorate(ubVariable, Decoration.Binding, (LiteralInteger)descriptor.Binding); + context.AddGlobalVariable(ubVariable); + context.UniformBuffers.Add(descriptor.Slot, ubVariable); + } } } @@ -154,7 +178,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv } int setIndex = context.Config.Options.TargetApi == TargetApi.Vulkan ? 1 : 0; - int count = descriptors.Max(x => x.Binding) + 1; + int count = descriptors.Max(x => x.Slot) + 1; var sbArrayType = context.TypeRuntimeArray(context.TypeU32()); context.Decorate(sbArrayType, Decoration.ArrayStride, (LiteralInteger)4); diff --git a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs index da7f17d0c..514722355 100644 --- a/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs +++ b/Ryujinx.Graphics.Shader/CodeGen/Spirv/Instructions.cs @@ -880,12 +880,26 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv var src1 = operation.GetSource(0); var src2 = context.Get(AggregateType.S32, operation.GetSource(1)); - var ubVariable = context.UniformBuffers[((AstOperand)src1).Value]; - var i0 = context.Constant(context.TypeS32(), 0); - var i1 = context.ShiftRightArithmetic(context.TypeS32(), src2, context.Constant(context.TypeS32(), 2)); - var i2 = context.BitwiseAnd(context.TypeS32(), src2, context.Constant(context.TypeS32(), 3)); + var i1 = context.Constant(context.TypeS32(), 0); + var i2 = context.ShiftRightArithmetic(context.TypeS32(), src2, context.Constant(context.TypeS32(), 2)); + var i3 = context.BitwiseAnd(context.TypeS32(), src2, context.Constant(context.TypeS32(), 3)); + + SpvInstruction elemPointer; + + if (context.UniformBuffersArray != null) + { + var ubVariable = context.UniformBuffersArray; + var i0 = context.Get(AggregateType.S32, src1); + + elemPointer = context.AccessChain(context.TypePointer(StorageClass.Uniform, context.TypeFP32()), ubVariable, i0, i1, i2, i3); + } + else + { + var ubVariable = context.UniformBuffers[((AstOperand)src1).Value]; + + elemPointer = context.AccessChain(context.TypePointer(StorageClass.Uniform, context.TypeFP32()), ubVariable, i1, i2, i3); + } - var elemPointer = context.AccessChain(context.TypePointer(StorageClass.Uniform, context.TypeFP32()), ubVariable, i0, i1, i2); var value = context.Load(context.TypeFP32(), elemPointer); return new OperationResult(AggregateType.FP32, value);