Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,20 @@
import com.google.errorprone.matchers.Matcher;
import com.google.errorprone.matchers.Matchers;
import com.google.errorprone.matchers.method.MethodMatchers;
import com.google.errorprone.predicates.TypePredicate;
import com.google.errorprone.predicates.TypePredicates;
import com.google.errorprone.util.ASTHelpers;
import com.sun.source.tree.ExpressionTree;
import com.sun.source.tree.ImportTree;
import com.sun.source.tree.MemberSelectTree;
import com.sun.source.tree.MethodInvocationTree;
import com.sun.source.tree.Tree;
import com.sun.source.util.SimpleTreeVisitor;
import com.sun.tools.javac.code.Type;
import java.util.List;
import java.util.Optional;
import java.util.function.BiConsumer;
import java.util.stream.Collectors;

/**
* {@link PreferAssertj} provides an automated path from legacy test libraries to AssertJ. Our goal is to migrate
Expand Down Expand Up @@ -286,42 +292,69 @@ public Description matchMethodInvocation(MethodInvocationTree tree, VisitorState
if (ASSERT_EQUALS_FLOATING.matches(tree, state)) {
return withAssertThat(tree, state, (assertThat, fix) -> fix
.addStaticImport("org.assertj.core.api.Assertions.within")
.replace(tree, assertThat + "(" + argSource(tree, state, 1) + ").isCloseTo("
+ argSource(tree, state, 0) + ", within(" + argSource(tree, state, 2) + "))"));
.replace(tree, String.format("%s(%s)%s",
assertThat,
argSource(tree, state, 1),
isConstantZero(tree.getArguments().get(2))
? String.format(".isEqualTo(%s)", argSource(tree, state, 0))
: String.format(".isCloseTo(%s, within(%s))",
argSource(tree, state, 0), argSource(tree, state, 2)))));
}
if (ASSERT_EQUALS_FLOATING_DESCRIPTION.matches(tree, state)) {
return withAssertThat(tree, state, (assertThat, fix) -> fix
.addStaticImport("org.assertj.core.api.Assertions.within")
.replace(tree, assertThat + "(" + argSource(tree, state, 2)
+ ").describedAs(" + argSource(tree, state, 0) + ").isCloseTo("
+ argSource(tree, state, 1) + ", within(" + argSource(tree, state, 3) + "))"));
.replace(tree, String.format("%s(%s).describedAs(%s)%s",
assertThat,
argSource(tree, state, 2),
argSource(tree, state, 0),
isConstantZero(tree.getArguments().get(3))
? String.format(".isEqualTo(%s)", argSource(tree, state, 1))
: String.format(".isCloseTo(%s, within(%s))",
argSource(tree, state, 1), argSource(tree, state, 3)))));
}
if (ASSERT_NOT_EQUALS_FLOATING.matches(tree, state)) {
return withAssertThat(tree, state, (assertThat, fix) -> fix
.addStaticImport("org.assertj.core.api.Assertions.within")
.replace(tree, assertThat + "(" + argSource(tree, state, 1) + ").isNotCloseTo("
+ argSource(tree, state, 0) + ", within(" + argSource(tree, state, 2) + "))"));
.replace(tree, String.format("%s(%s)%s",
assertThat,
argSource(tree, state, 1),
isConstantZero(tree.getArguments().get(2))
? String.format(".isNotEqualTo(%s)", argSource(tree, state, 0))
: String.format(".isNotCloseTo(%s, within(%s))",
argSource(tree, state, 0), argSource(tree, state, 2)))));
}
if (ASSERT_NOT_EQUALS_FLOATING_DESCRIPTION.matches(tree, state)) {
return withAssertThat(tree, state, (assertThat, fix) -> fix
.addStaticImport("org.assertj.core.api.Assertions.within")
.replace(tree, assertThat + "(" + argSource(tree, state, 2)
+ ").describedAs(" + argSource(tree, state, 0) + ").isNotCloseTo("
+ argSource(tree, state, 1) + ", within(" + argSource(tree, state, 3) + "))"));
.replace(tree, String.format("%s(%s).describedAs(%s)%s",
assertThat,
argSource(tree, state, 2),
argSource(tree, state, 0),
isConstantZero(tree.getArguments().get(3))
? String.format(".isNotEqualTo(%s)", argSource(tree, state, 1))
: String.format(".isNotCloseTo(%s, within(%s))",
argSource(tree, state, 1), argSource(tree, state, 3)))));
}
if (ASSERT_THAT.matches(tree, state)) {
Optional<String> replacement = tree.getArguments().get(1)
.accept(HamcrestVisitor.INSTANCE, state);
return withAssertThat(tree, state, (assertThat, fix) ->
fix.replace(tree, assertThat + "(" + argSource(tree, state, 0) + ").is(new "
fix.replace(tree, assertThat + "(" + argSource(tree, state, 0) + ")"
+ replacement.orElseGet(() ->
".is(new "
+ SuggestedFixes.qualifyType(state, fix, "org.assertj.core.api.HamcrestCondition")
+ "<>("
+ argSource(tree, state, 1) + "))"));
+ argSource(tree, state, 1) + "))")));
}
if (ASSERT_THAT_DESCRIPTION.matches(tree, state)) {
Optional<String> replacement = tree.getArguments().get(2)
.accept(HamcrestVisitor.INSTANCE, state);
return withAssertThat(tree, state, (assertThat, fix) ->
fix.replace(tree, assertThat + "(" + argSource(tree, state, 1)
+ ").describedAs(" + argSource(tree, state, 0) + ").is(new "
+ ").describedAs(" + argSource(tree, state, 0) + ")"
+ replacement.orElseGet(() -> ".is(new "
+ SuggestedFixes.qualifyType(state, fix, "org.assertj.core.api.HamcrestCondition")
+ "<>(" + argSource(tree, state, 2) + "))"));
+ "<>(" + argSource(tree, state, 2) + "))")));
}
if (ASSERT_EQUALS_CATCHALL.matches(tree, state)) {
int parameters = tree.getArguments().size();
Expand Down Expand Up @@ -414,6 +447,11 @@ private static boolean useStaticAssertjImport(VisitorState state) {
return true;
}

private static boolean isConstantZero(Tree tree) {
Object constantValue = ASTHelpers.constValue(tree);
return constantValue instanceof Number && ((Number) constantValue).doubleValue() == 0D;
}

private static boolean isExpressionSameType(VisitorState state, MemberSelectTree memberSelectTree, String type) {
return ASTHelpers.isSameType(
ASTHelpers.getType(memberSelectTree.getExpression()),
Expand All @@ -427,4 +465,187 @@ private static String argSource(MethodInvocationTree invocation, VisitorState st
checkArgument(index < arguments.size(), "Index is out of bounds");
return checkNotNull(state.getSourceForNode(arguments.get(index)), "Failed to find argument source");
}

private static final class HamcrestVisitor extends SimpleTreeVisitor<Optional<String>, VisitorState> {
private static final HamcrestVisitor INSTANCE = new HamcrestVisitor(false);

private static final HamcrestVisitor NEGATED = new HamcrestVisitor(true);

private static final TypePredicate MATCHERS = new TypePredicate() {

private final TypePredicate matcherPredicate = TypePredicates.isDescendantOf("org.hamcrest.Matcher");
private final TypePredicate[] predicates = new TypePredicate[] {
TypePredicates.isExactType("org.hamcrest.Matchers"),
TypePredicates.isExactType("org.hamcrest.CoreMatchers"),
// Allows uses of direct imports to be migrated as well,
// e.g. 'org.hamcrest.core.Is.is'.
(TypePredicate) (type, state) -> matcherPredicate.apply(type, state)
// Limit to Hamcrest packages to avoid interaction with non-standard library code
&& type.toString().startsWith("org.hamcrest.")
};

@Override
public boolean apply(Type type, VisitorState state) {
for (TypePredicate predicate : predicates) {
if (predicate.apply(type, state)) {
return true;
}
}
return false;
}
};

private static final Matcher<ExpressionTree> IS_MATCHER = MethodMatchers.staticMethod()
.onClass(MATCHERS)
.named("is")
.withParameters("org.hamcrest.Matcher");

private static final Matcher<ExpressionTree> NOT_MATCHER = MethodMatchers.staticMethod()
.onClass(MATCHERS)
.named("not")
.withParameters("org.hamcrest.Matcher");

private static final Matcher<ExpressionTree> EQUALS = MethodMatchers.staticMethod()
.onClass(MATCHERS)
.namedAnyOf("is", "equalTo", "equalToObject")
.withParameters(Object.class.getName());

private static final Matcher<ExpressionTree> INSTANCE_OF = MethodMatchers.staticMethod()
.onClass(MATCHERS)
.namedAnyOf("isA", "instanceOf")
.withParameters("java.lang.Class");

private static final Matcher<ExpressionTree> NULL = Matchers.anyOf(
MethodMatchers.staticMethod()
.onClass(MATCHERS)
.named("nullValue")
.withParameters(),
MethodMatchers.staticMethod()
.onClass(MATCHERS)
.named("nullValue")
.withParameters("java.lang.Class"));

private static final Matcher<ExpressionTree> NOT_NULL = Matchers.anyOf(
MethodMatchers.staticMethod()
.onClass(MATCHERS)
.named("notNullValue")
.withParameters(),
MethodMatchers.staticMethod()
.onClass(MATCHERS)
.named("notNullValue")
.withParameters("java.lang.Class"));

private static final Matcher<ExpressionTree> CONTAINS = Matchers.anyOf(
MethodMatchers.staticMethod()
.onClass(MATCHERS)
.namedAnyOf("hasItem", "hasItemInArray")
.withParameters(Object.class.getName()),
MethodMatchers.staticMethod()
.onClass(MATCHERS)
.named("containsString")
.withParameters(String.class.getName()));

// Note: cannot match array/vararg arguments
private static final Matcher<ExpressionTree> HAS_ITEMS = MethodMatchers.staticMethod()
.onClass(MATCHERS)
.namedAnyOf("hasItems", "arrayContainingInAnyOrder");

private static final Matcher<ExpressionTree> IS_EMPTY = Matchers.anyOf(
MethodMatchers.staticMethod()
.onClass(MATCHERS)
.namedAnyOf("empty", "emptyIterable", "emptyArray", "anEmptyMap")
.withParameters(),
MethodMatchers.staticMethod()
.onClass(MATCHERS)
.namedAnyOf("emptyCollectionOf", "emptyIterableOf")
.withParameters(Class.class.getName()));

private static final Matcher<ExpressionTree> HAS_SIZE = MethodMatchers.staticMethod()
.onClass(MATCHERS)
.namedAnyOf("hasSize", "aMapWithSize", "arrayWithSize", "iterableWithSize")
.withParameters(int.class.getName());

private static final Matcher<ExpressionTree> SAME_INSTANCE = MethodMatchers.staticMethod()
.onClass(MATCHERS)
.namedAnyOf("sameInstance", "theInstance")
.withParameters(Object.class.getName());

private static final Matcher<ExpressionTree> STARTS_WITH = MethodMatchers.staticMethod()
.onClass(MATCHERS)
.namedAnyOf("startsWith")
.withParameters(String.class.getName());

private static final Matcher<ExpressionTree> ENDS_WITH = MethodMatchers.staticMethod()
.onClass(MATCHERS)
.namedAnyOf("endsWith")
.withParameters(String.class.getName());

private final boolean negated;

private HamcrestVisitor(boolean negated) {
super(Optional.empty());
this.negated = negated;
}

@Override
@SuppressWarnings("CyclomaticComplexity")
public Optional<String> visitMethodInvocation(MethodInvocationTree node, VisitorState state) {
// is(Matcher) is for readability, fall through to the next matcher
if (IS_MATCHER.matches(node, state)) {
return node.getArguments().get(0).accept(this, state);
}
if (NOT_MATCHER.matches(node, state)) {
return node.getArguments().get(0).accept(this.negated ? INSTANCE : NEGATED, state);
}
if (EQUALS.matches(node, state)) {
return Optional.of((negated ? ".isNotEqualTo(" : ".isEqualTo(")
+ argSource(node, state, 0) + ")");
}
if (INSTANCE_OF.matches(node, state)) {
return Optional.of((negated ? ".isNotInstanceOf(" : ".isInstanceOf(")
+ argSource(node, state, 0) + ")");
}
if (NULL.matches(node, state)) {
return Optional.of(negated ? ".isNotNull()" : ".isNull()");
}
if (NOT_NULL.matches(node, state)) {
return Optional.of(negated ? ".isNull()" : ".isNotNull()");
}
if (CONTAINS.matches(node, state)) {
return Optional.of((negated ? ".doesNotContain(" : ".contains(") + argSource(node, state, 0) + ')');
}
if (IS_EMPTY.matches(node, state)) {
return Optional.of(negated ? ".isNotEmpty()" : ".isEmpty()");
}
if (SAME_INSTANCE.matches(node, state)) {
return Optional.of((negated ? ".isNotSameAs(" : ".isSameAs(") + argSource(node, state, 0) + ')');
}
if (STARTS_WITH.matches(node, state)) {
return Optional.of((negated ? ".doesNotStartWith(" : ".startsWith(") + argSource(node, state, 0) + ')');
}
if (ENDS_WITH.matches(node, state)) {
return Optional.of((negated ? ".doesNotEndWith(" : ".endsWith(") + argSource(node, state, 0) + ')');
}
if (HAS_SIZE.matches(node, state)) {
if (negated) {
// No top level method for negated size assertions
return Optional.empty();
}

return Optional.of(".hasSize(" + argSource(node, state, 0) + ')');
}
if (HAS_ITEMS.matches(node, state)
&& checkNotNull(ASTHelpers.getSymbol(node), "symbol").isVarArgs()) {
if (negated) {
// this negates to 'doesNotContainAll' which doesn't exist. AssertJ doesNotContain
// evaluates as 'does not contain any'.
return Optional.empty();
}
return Optional.of(".contains(" + node.getArguments().stream()
.map(state::getSourceForNode)
.collect(Collectors.joining(", ")) + ')');
}
return Optional.empty();
}
}
}
Loading