Fix some shader gen problems…

This commit is contained in:
Isaac Marovitz 2024-03-19 17:18:59 -04:00 committed by Evan Husted
parent dc4305f1cf
commit e2445990a5
3 changed files with 21 additions and 15 deletions

View file

@ -65,11 +65,11 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{ {
if (stage == ShaderStage.Vertex) if (stage == ShaderStage.Vertex)
{ {
context.AppendLine("VertexOutput out;"); context.AppendLine("VertexOut out;");
} }
else if (stage == ShaderStage.Fragment) else if (stage == ShaderStage.Fragment)
{ {
context.AppendLine("FragmentOutput out;"); context.AppendLine("FragmentOut out;");
} }
foreach (AstOperand decl in function.Locals) foreach (AstOperand decl in function.Locals)
@ -120,17 +120,16 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
switch (context.Definitions.Stage) switch (context.Definitions.Stage)
{ {
case ShaderStage.Vertex: case ShaderStage.Vertex:
prefix = "Vertex"; context.AppendLine($"struct VertexIn");
break; break;
case ShaderStage.Fragment: case ShaderStage.Fragment:
prefix = "Fragment"; context.AppendLine($"struct VertexOut");
break; break;
case ShaderStage.Compute: case ShaderStage.Compute:
prefix = "Compute"; context.AppendLine($"struct ComputeIn");
break; break;
} }
context.AppendLine($"struct {prefix}In");
context.EnterScope(); context.EnterScope();
foreach (var ioDefinition in inputs.OrderBy(x => x.Location)) foreach (var ioDefinition in inputs.OrderBy(x => x.Location))
@ -162,31 +161,38 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
switch (context.Definitions.Stage) switch (context.Definitions.Stage)
{ {
case ShaderStage.Vertex: case ShaderStage.Vertex:
prefix = "Vertex"; context.AppendLine($"struct VertexOut");
break; break;
case ShaderStage.Fragment: case ShaderStage.Fragment:
prefix = "Fragment"; context.AppendLine($"struct FragmentOut");
break; break;
case ShaderStage.Compute: case ShaderStage.Compute:
prefix = "Compute"; context.AppendLine($"struct ComputeOut");
break; break;
} }
context.AppendLine($"struct {prefix}Output");
context.EnterScope(); context.EnterScope();
foreach (var ioDefinition in inputs.OrderBy(x => x.Location)) foreach (var ioDefinition in inputs.OrderBy(x => x.Location))
{ {
string type = GetVarTypeName(context, context.Definitions.GetUserDefinedType(ioDefinition.Location, isOutput: true)); string type = ioDefinition.IoVariable switch
{
IoVariable.Position => "float4",
IoVariable.PointSize => "float",
_ => GetVarTypeName(context, context.Definitions.GetUserDefinedType(ioDefinition.Location, isOutput: true))
};
string name = ioDefinition.IoVariable switch string name = ioDefinition.IoVariable switch
{ {
IoVariable.Position => "position", IoVariable.Position => "position",
IoVariable.PointSize => "point_size",
IoVariable.FragmentOutputColor => "color", IoVariable.FragmentOutputColor => "color",
_ => $"{DefaultNames.OAttributePrefix}{ioDefinition.Location}" _ => $"{DefaultNames.OAttributePrefix}{ioDefinition.Location}"
}; };
string suffix = ioDefinition.IoVariable switch string suffix = ioDefinition.IoVariable switch
{ {
IoVariable.Position => " [[position]]", IoVariable.Position => " [[position]]",
IoVariable.PointSize => " [[point_size]]",
IoVariable.FragmentOutputColor => $" [[color({ioDefinition.Location})]]",
_ => "" _ => ""
}; };

View file

@ -24,7 +24,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl.Instructions
IoVariable.FrontFacing => ("front_facing", AggregateType.Bool), IoVariable.FrontFacing => ("front_facing", AggregateType.Bool),
IoVariable.InstanceId => ("instance_id", AggregateType.S32), IoVariable.InstanceId => ("instance_id", AggregateType.S32),
IoVariable.PointCoord => ("point_coord", AggregateType.Vector2), IoVariable.PointCoord => ("point_coord", AggregateType.Vector2),
IoVariable.PointSize => ("point_size", AggregateType.FP32), IoVariable.PointSize => ("out.point_size", AggregateType.FP32),
IoVariable.Position => ("out.position", AggregateType.Vector4 | AggregateType.FP32), IoVariable.Position => ("out.position", AggregateType.Vector4 | AggregateType.FP32),
IoVariable.PrimitiveId => ("primitive_id", AggregateType.S32), IoVariable.PrimitiveId => ("primitive_id", AggregateType.S32),
IoVariable.UserDefined => GetUserDefinedVariableName(definitions, location, component, isOutput, isPerPatch), IoVariable.UserDefined => GetUserDefinedVariableName(definitions, location, component, isOutput, isPerPatch),

View file

@ -85,13 +85,13 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
{ {
funcKeyword = "vertex"; funcKeyword = "vertex";
funcName = "vertexMain"; funcName = "vertexMain";
returnType = "VertexOutput"; returnType = "VertexOut";
} }
else if (stage == ShaderStage.Fragment) else if (stage == ShaderStage.Fragment)
{ {
funcKeyword = "fragment"; funcKeyword = "fragment";
funcName = "fragmentMain"; funcName = "fragmentMain";
returnType = "FragmentOutput"; returnType = "FragmentOut";
} }
else if (stage == ShaderStage.Compute) else if (stage == ShaderStage.Compute)
{ {
@ -106,7 +106,7 @@ namespace Ryujinx.Graphics.Shader.CodeGen.Msl
} }
else if (stage == ShaderStage.Fragment) else if (stage == ShaderStage.Fragment)
{ {
args = args.Prepend("FragmentIn in").ToArray(); args = args.Prepend("VertexOut in").ToArray();
} }
else if (stage == ShaderStage.Compute) else if (stage == ShaderStage.Compute)
{ {