Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 33 additions & 5 deletions docs/Rules/MA0095.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,43 @@
Source: [EqualityShouldBeCorrectlyImplementedAnalyzer.cs](https://github.com/meziantou/Meziantou.Analyzer/blob/main/src/Meziantou.Analyzer/Rules/EqualityShouldBeCorrectlyImplementedAnalyzer.cs)
<!-- sources -->

When a type directly implements `IEquatable<T>`, it should also override `Equals(object)` to ensure consistent equality behavior across different contexts.

## Non-compliant Code

````c#
// non-compliant
public sealed class Test : IEquatable<Test>
{
public bool Equals(Test? other) => throw null;
}
````

````c#
class Test : IEquatable<T> // non-compliant
public class Base : IEquatable<Base>
{
public bool Equals(Test other) => throw null;
public bool Equals(Base? other) => throw null;
public override bool Equals(object? obj) => throw null;
public override int GetHashCode() => 0;
}

class Test : IEquatable<T> // ok
// non-compliant
public sealed class Derived : Base, IEquatable<Derived>
{
public bool Equals(Derived? other) => throw null;
// Missing override of Equals(object)
}
````

## Compliant Code

Direct implementation with `Equals(object)` override:

````c#
public sealed class Test : IEquatable<Test>
{
public override bool Equals(object other) => throw null;
public bool Equals(Test other) => throw null;
public override bool Equals(object? obj) => obj is Test other && Equals(other);
public bool Equals(Test? other) => throw null;
public override int GetHashCode() => 0;
}
````
6 changes: 3 additions & 3 deletions global.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"rollForward": "latestPatch"
},
"msbuild-sdks": {
"Meziantou.NET.Sdk": "1.0.38",
"Meziantou.NET.Sdk.Test": "1.0.38",
"Meziantou.NET.Sdk.Web": "1.0.38"
"Meziantou.NET.Sdk": "1.0.40",
"Meziantou.NET.Sdk.Test": "1.0.40",
"Meziantou.NET.Sdk.Web": "1.0.40"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ public override void Initialize(AnalysisContext context)

private sealed class AnalyzerContext(Compilation compilation)
{
public IMethodSymbol? ObjectEqualsSymbol { get; set; } = compilation.GetSpecialType(SpecialType.System_Object).GetMembers("Equals").FirstOrDefault() as IMethodSymbol;
public IMethodSymbol? ValueTypeEqualsSymbol { get; set; } = compilation.GetSpecialType(SpecialType.System_ValueType).GetMembers("Equals").FirstOrDefault() as IMethodSymbol;

public INamedTypeSymbol? IComparableSymbol { get; set; } = compilation.GetBestTypeByMetadataName("System.IComparable");
public INamedTypeSymbol? IComparableOfTSymbol { get; set; } = compilation.GetBestTypeByMetadataName("System.IComparable`1");
public INamedTypeSymbol? IEquatableOfTSymbol { get; set; } = compilation.GetBestTypeByMetadataName("System.IEquatable`1");
Expand All @@ -96,6 +99,7 @@ public void AnalyzeSymbol(SymbolAnalysisContext context)
var implementIComparable = false;
var implementIComparableOfT = false;
var implementIEquatableOfT = false;
var directlyImplementIEquatableOfT = false;
foreach (var implementedInterface in symbol.AllInterfaces)
{
if (implementedInterface.IsEqualTo(IComparableSymbol))
Expand All @@ -112,6 +116,16 @@ public void AnalyzeSymbol(SymbolAnalysisContext context)
}
}

// Check if the type directly implements IEquatable<T> (not inherited from base class)
foreach (var implementedInterface in symbol.Interfaces)
{
if (IEquatableOfTSymbol is not null && implementedInterface.IsEqualTo(IEquatableOfTSymbol.Construct(symbol)))
{
directlyImplementIEquatableOfT = true;
break;
}
}

// IComparable without IComparable<T>
if (implementIComparable && !implementIComparableOfT)
{
Expand All @@ -125,7 +139,9 @@ public void AnalyzeSymbol(SymbolAnalysisContext context)
}

// IEquatable<T> without Equals(object)
if (implementIEquatableOfT && !HasMethod(symbol, IsEqualsMethod))
// Only report if directly implemented (not inherited via CRTP - Curiously Recurring Template Pattern)
// Check the entire type hierarchy for an Equals(object) override
if (directlyImplementIEquatableOfT && !symbol.GetMembers().OfType<IMethodSymbol>().Any(IsEqualsMethodOverride))
{
context.ReportDiagnostic(OverrideEqualsObjectRule, symbol);
}
Expand All @@ -148,22 +164,32 @@ public void AnalyzeSymbol(SymbolAnalysisContext context)
context.ReportDiagnostic(AddComparisonRule, symbol);
}
}
}

private static bool HasMethod(INamedTypeSymbol parentType, Func<IMethodSymbol, bool> predicate)
{
foreach (var member in parentType.GetMembers().OfType<IMethodSymbol>())
private static bool HasMethod(INamedTypeSymbol parentType, Func<IMethodSymbol, bool> predicate)
{
if (predicate(member))
return true;
foreach (var member in parentType.GetMembers().OfType<IMethodSymbol>())
{
if (predicate(member))
return true;
}

return false;
}

return false;
}
private static bool HasMethodInHierarchy(INamedTypeSymbol type, Func<IMethodSymbol, bool> predicate)
{
foreach (var member in type.GetAllMembers().OfType<IMethodSymbol>())
{
if (predicate(member))
return true;
}

private static bool HasComparisonOperator(INamedTypeSymbol parentType)
{
var operatorNames = new List<string>(6)
return false;
}

private static bool HasComparisonOperator(INamedTypeSymbol parentType)
{
var operatorNames = new List<string>(6)
{
"op_LessThan",
"op_LessThanOrEqual",
Expand All @@ -173,44 +199,58 @@ private static bool HasComparisonOperator(INamedTypeSymbol parentType)
"op_Inequality",
};

foreach (var member in parentType.GetAllMembers().OfType<IMethodSymbol>())
{
if (member.MethodKind is MethodKind.UserDefinedOperator)
foreach (var member in parentType.GetAllMembers().OfType<IMethodSymbol>())
{
operatorNames.Remove(member.Name);
if (member.MethodKind is MethodKind.UserDefinedOperator)
{
operatorNames.Remove(member.Name);
}
}

return operatorNames.Count == 0;
}

return operatorNames.Count == 0;
}
private static bool IsEqualsMethod(IMethodSymbol symbol)
{
return symbol.Name == nameof(object.Equals) &&
symbol.ReturnType.IsBoolean() &&
symbol.Parameters.Length == 1 &&
symbol.Parameters[0].Type.IsObject() &&
symbol.DeclaredAccessibility == Accessibility.Public &&
!symbol.IsStatic;
}

private static bool IsEqualsMethod(IMethodSymbol symbol)
{
return symbol.Name == nameof(object.Equals) &&
symbol.ReturnType.IsBoolean() &&
symbol.Parameters.Length == 1 &&
symbol.Parameters[0].Type.IsObject() &&
symbol.DeclaredAccessibility == Accessibility.Public &&
!symbol.IsStatic;
}
private bool IsEqualsMethodOverride(IMethodSymbol symbol)
{
// Check if it's an Equals(object) method AND it's overridden (not the base System.Object method)
return symbol.Name == nameof(object.Equals) &&
symbol.ReturnType.IsBoolean() &&
symbol.Parameters.Length == 1 &&
symbol.Parameters[0].Type.IsObject() &&
symbol.DeclaredAccessibility == Accessibility.Public &&
!symbol.IsStatic &&
!symbol.IsEqualTo(ObjectEqualsSymbol) &&
!symbol.IsEqualTo(ValueTypeEqualsSymbol);
}

private static bool IsCompareToMethod(IMethodSymbol symbol)
{
return symbol.Name == nameof(IComparable.CompareTo) &&
symbol.ReturnType.IsInt32() &&
symbol.Parameters.Length == 1 &&
symbol.Parameters[0].Type.IsObject() &&
symbol.DeclaredAccessibility == Accessibility.Public &&
!symbol.IsStatic;
}
private static bool IsCompareToMethod(IMethodSymbol symbol)
{
return symbol.Name == nameof(IComparable.CompareTo) &&
symbol.ReturnType.IsInt32() &&
symbol.Parameters.Length == 1 &&
symbol.Parameters[0].Type.IsObject() &&
symbol.DeclaredAccessibility == Accessibility.Public &&
!symbol.IsStatic;
}

private static bool IsCompareToOfTMethod(IMethodSymbol symbol)
{
return symbol.Name == nameof(IComparable.CompareTo) &&
symbol.ReturnType.IsInt32() &&
symbol.Parameters.Length == 1 &&
symbol.Parameters[0].Type.IsEqualTo(symbol.ContainingType) &&
symbol.DeclaredAccessibility == Accessibility.Public &&
!symbol.IsStatic;
private static bool IsCompareToOfTMethod(IMethodSymbol symbol)
{
return symbol.Name == nameof(IComparable.CompareTo) &&
symbol.ReturnType.IsInt32() &&
symbol.Parameters.Length == 1 &&
symbol.Parameters[0].Type.IsEqualTo(symbol.ContainingType) &&
symbol.DeclaredAccessibility == Accessibility.Public &&
!symbol.IsStatic;
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
using Meziantou.Analyzer.Rules;
using TestHelper;

namespace Meziantou.Analyzer.Test.Rules;

public sealed class EqualityShouldBeCorrectlyImplementedAnalyzerMA0095Tests
{
private static ProjectBuilder CreateProjectBuilder()
{
return new ProjectBuilder()
.WithAnalyzer<EqualityShouldBeCorrectlyImplementedAnalyzer>()
.WithCodeFixProvider<EqualityShouldBeCorrectlyImplementedFixer>();
}

[Fact]
public async Task DirectImplementation_WithoutEqualsObject_ShouldTrigger()
{
var originalCode = """
using System;

public sealed class [|TriggersMA0095AndCA1067|] : IEquatable<TriggersMA0095AndCA1067>
{
public bool Equals(TriggersMA0095AndCA1067? other) => true;
}
""";

await CreateProjectBuilder()
.WithSourceCode(originalCode)
.ValidateAsync();
}

[Fact]
public async Task DirectImplementation_WithEqualsObject_ShouldNotTrigger()
{
var originalCode = """
using System;

public sealed class Test : IEquatable<Test>
{
public bool Equals(Test? other) => true;
public override bool Equals(object? obj) => true;
public override int GetHashCode() => 0;
}
""";

await CreateProjectBuilder()
.WithSourceCode(originalCode)
.ValidateAsync();
}

#if ROSLYN_4_8_OR_GREATER
[Fact]
public async Task CRTP_WithoutEqualsObject_ShouldNotTrigger()
{
var originalCode = """
using System;

public abstract class Crtp<T> : IEquatable<T> where T : Crtp<T>
{
public bool Equals(T? other) => true;
}

public sealed class TriggersMA0095Only : Crtp<TriggersMA0095Only>;
""";

await CreateProjectBuilder()
.WithSourceCode(originalCode)
.ValidateAsync();
}

[Fact]
public async Task CRTP_WithEqualsObjectInBase_ShouldNotTrigger()
{
var originalCode = """
using System;

public abstract class Crtp<T> : IEquatable<T> where T : Crtp<T>
{
public bool Equals(T? other) => true;
public override bool Equals(object? obj) => true;
public override int GetHashCode() => 0;
}

public sealed class DerivedClass : Crtp<DerivedClass>;
""";

await CreateProjectBuilder()
.WithSourceCode(originalCode)
.ValidateAsync();
}
#endif

[Fact]
public async Task InheritedIEquatable_WithDirectImplementationToo_ShouldTrigger()
{
var originalCode = """
using System;

public abstract class Base : IEquatable<Base>
{
public bool Equals(Base? other) => true;
public override bool Equals(object? obj) => true;
public override int GetHashCode() => 0;
}

public sealed class [|Derived|] : Base, IEquatable<Derived>
{
public bool Equals(Derived? other) => true;
}
""";

await CreateProjectBuilder()
.WithSourceCode(originalCode)
.ValidateAsync();
}

[Fact]
public async Task Struct_DirectImplementation_WithoutEqualsObject_ShouldTrigger()
{
var originalCode = """
using System;

public struct [|TestStruct|] : IEquatable<TestStruct>
{
public bool Equals(TestStruct other) => true;
}
""";

await CreateProjectBuilder()
.WithSourceCode(originalCode)
.ValidateAsync();
}
}