SPIR-V: Constant buffer indexing support

This commit is contained in:
gdk 2022-04-07 11:45:04 -03:00 committed by riperiperi
parent 1448136c3d
commit d5e2cc2f9b
3 changed files with 69 additions and 16 deletions

View file

@ -20,6 +20,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
public int InputVertices { get; } public int InputVertices { get; }
public Dictionary<int, Instruction> UniformBuffers { get; } = new Dictionary<int, Instruction>(); public Dictionary<int, Instruction> UniformBuffers { get; } = new Dictionary<int, Instruction>();
public Instruction UniformBuffersArray { get; set; }
public Instruction StorageBuffersArray { get; set; } public Instruction StorageBuffersArray { get; set; }
public Instruction LocalMemory { get; set; } public Instruction LocalMemory { get; set; }
public Instruction SharedMemory { get; set; } public Instruction SharedMemory { get; set; }
@ -332,12 +333,26 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
public Instruction GetConstantBuffer(AggregateType type, AstOperand operand) public Instruction GetConstantBuffer(AggregateType type, AstOperand operand)
{ {
var ubVariable = UniformBuffers[operand.CbufSlot]; var i1 = Constant(TypeS32(), 0);
var i0 = Constant(TypeS32(), 0); var i2 = Constant(TypeS32(), operand.CbufOffset >> 2);
var i1 = Constant(TypeS32(), operand.CbufOffset >> 2); var i3 = Constant(TypeU32(), operand.CbufOffset & 3);
var i2 = Constant(TypeU32(), operand.CbufOffset & 3);
Instruction elemPointer;
if (UniformBuffersArray != null)
{
var ubVariable = UniformBuffersArray;
var i0 = Constant(TypeS32(), operand.CbufSlot);
elemPointer = AccessChain(TypePointer(StorageClass.Uniform, TypeFP32()), ubVariable, i0, i1, i2, i3);
}
else
{
var ubVariable = UniformBuffers[operand.CbufSlot];
elemPointer = AccessChain(TypePointer(StorageClass.Uniform, TypeFP32()), ubVariable, i1, i2, i3);
}
var elemPointer = AccessChain(TypePointer(StorageClass.Uniform, TypeFP32()), ubVariable, i0, i1, i2);
return BitcastIfNeeded(type, AggregateType.FP32, Load(TypeFP32(), elemPointer)); return BitcastIfNeeded(type, AggregateType.FP32, Load(TypeFP32(), elemPointer));
} }

View file

@ -125,6 +125,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
private static void DeclareUniformBuffers(CodeGenContext context, BufferDescriptor[] descriptors) private static void DeclareUniformBuffers(CodeGenContext context, BufferDescriptor[] descriptors)
{ {
if (descriptors.Length == 0)
{
return;
}
uint ubSize = Constants.ConstantBufferSize / 16; uint ubSize = Constants.ConstantBufferSize / 16;
var ubArrayType = context.TypeArray(context.TypeVector(context.TypeFP32(), 4), context.Constant(context.TypeU32(), ubSize), true); var ubArrayType = context.TypeArray(context.TypeVector(context.TypeFP32(), 4), context.Constant(context.TypeU32(), ubSize), true);
@ -132,6 +137,24 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
var ubStructType = context.TypeStruct(true, ubArrayType); var ubStructType = context.TypeStruct(true, ubArrayType);
context.Decorate(ubStructType, Decoration.Block); context.Decorate(ubStructType, Decoration.Block);
context.MemberDecorate(ubStructType, 0, Decoration.Offset, (LiteralInteger)0); context.MemberDecorate(ubStructType, 0, Decoration.Offset, (LiteralInteger)0);
if (context.Config.UsedFeatures.HasFlag(FeatureFlags.CbIndexing))
{
int count = descriptors.Max(x => x.Slot) + 1;
var ubStructArrayType = context.TypeArray(ubStructType, context.Constant(context.TypeU32(), count));
var ubPointerType = context.TypePointer(StorageClass.Uniform, ubStructArrayType);
var ubVariable = context.Variable(ubPointerType, StorageClass.Uniform);
context.Name(ubVariable, $"{GetStagePrefix(context.Config.Stage)}_u");
context.Decorate(ubVariable, Decoration.DescriptorSet, (LiteralInteger)0);
context.Decorate(ubVariable, Decoration.Binding, (LiteralInteger)context.Config.FirstConstantBufferBinding);
context.AddGlobalVariable(ubVariable);
context.UniformBuffersArray = ubVariable;
}
else
{
var ubPointerType = context.TypePointer(StorageClass.Uniform, ubStructType); var ubPointerType = context.TypePointer(StorageClass.Uniform, ubStructType);
foreach (var descriptor in descriptors) foreach (var descriptor in descriptors)
@ -145,6 +168,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
context.UniformBuffers.Add(descriptor.Slot, ubVariable); context.UniformBuffers.Add(descriptor.Slot, ubVariable);
} }
} }
}
private static void DeclareStorageBuffers(CodeGenContext context, BufferDescriptor[] descriptors) private static void DeclareStorageBuffers(CodeGenContext context, BufferDescriptor[] descriptors)
{ {
@ -154,7 +178,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
} }
int setIndex = context.Config.Options.TargetApi == TargetApi.Vulkan ? 1 : 0; int setIndex = context.Config.Options.TargetApi == TargetApi.Vulkan ? 1 : 0;
int count = descriptors.Max(x => x.Binding) + 1; int count = descriptors.Max(x => x.Slot) + 1;
var sbArrayType = context.TypeRuntimeArray(context.TypeU32()); var sbArrayType = context.TypeRuntimeArray(context.TypeU32());
context.Decorate(sbArrayType, Decoration.ArrayStride, (LiteralInteger)4); context.Decorate(sbArrayType, Decoration.ArrayStride, (LiteralInteger)4);

View file

@ -880,12 +880,26 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Spirv
var src1 = operation.GetSource(0); var src1 = operation.GetSource(0);
var src2 = context.Get(AggregateType.S32, operation.GetSource(1)); var src2 = context.Get(AggregateType.S32, operation.GetSource(1));
var ubVariable = context.UniformBuffers[((AstOperand)src1).Value]; var i1 = context.Constant(context.TypeS32(), 0);
var i0 = context.Constant(context.TypeS32(), 0); var i2 = context.ShiftRightArithmetic(context.TypeS32(), src2, context.Constant(context.TypeS32(), 2));
var i1 = context.ShiftRightArithmetic(context.TypeS32(), src2, context.Constant(context.TypeS32(), 2)); var i3 = context.BitwiseAnd(context.TypeS32(), src2, context.Constant(context.TypeS32(), 3));
var i2 = context.BitwiseAnd(context.TypeS32(), src2, context.Constant(context.TypeS32(), 3));
SpvInstruction elemPointer;
if (context.UniformBuffersArray != null)
{
var ubVariable = context.UniformBuffersArray;
var i0 = context.Get(AggregateType.S32, src1);
elemPointer = context.AccessChain(context.TypePointer(StorageClass.Uniform, context.TypeFP32()), ubVariable, i0, i1, i2, i3);
}
else
{
var ubVariable = context.UniformBuffers[((AstOperand)src1).Value];
elemPointer = context.AccessChain(context.TypePointer(StorageClass.Uniform, context.TypeFP32()), ubVariable, i1, i2, i3);
}
var elemPointer = context.AccessChain(context.TypePointer(StorageClass.Uniform, context.TypeFP32()), ubVariable, i0, i1, i2);
var value = context.Load(context.TypeFP32(), elemPointer); var value = context.Load(context.TypeFP32(), elemPointer);
return new OperationResult(AggregateType.FP32, value); return new OperationResult(AggregateType.FP32, value);