From ffc642048e6c7b1d152a3b76fe918b4f9b68e76a Mon Sep 17 00:00:00 2001 From: Skylot Date: Thu, 18 Dec 2014 22:24:28 +0300 Subject: [PATCH] core: fix type check for loop over iterable. --- .../visitors/regions/LoopRegionVisitor.java | 44 +++++++++++-------- .../loops/TestIterableForEach3.java | 43 ++++++++++++++++++ 2 files changed, 69 insertions(+), 18 deletions(-) create mode 100644 jadx-core/src/test/java/jadx/tests/integration/loops/TestIterableForEach3.java diff --git a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/LoopRegionVisitor.java b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/LoopRegionVisitor.java index bf84e9764..ab2df63af 100644 --- a/jadx-core/src/main/java/jadx/core/dex/visitors/regions/LoopRegionVisitor.java +++ b/jadx-core/src/main/java/jadx/core/dex/visitors/regions/LoopRegionVisitor.java @@ -224,7 +224,8 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor InsnArg iterableArg = assignInsn.getArg(0); InsnNode hasNextCall = useList.get(0).getParentInsn(); InsnNode nextCall = useList.get(1).getParentInsn(); - if (!checkInvoke(hasNextCall, "java.util.Iterator", "hasNext()Z", 0) + if (hasNextCall == null || nextCall == null + || !checkInvoke(hasNextCall, "java.util.Iterator", "hasNext()Z", 0) || !checkInvoke(nextCall, "java.util.Iterator", "next()Ljava/lang/Object;", 0)) { return false; } @@ -239,7 +240,7 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor } else { iterVar = parentInsn.getResult(); InsnArg castArg = BlockUtils.searchWrappedInsnParent(mth, parentInsn); - if (castArg != null) { + if (castArg != null && castArg.getParentInsn() != null) { castArg.getParentInsn().replaceArg(castArg, iterVar); } else { // cast not inlined @@ -266,27 +267,34 @@ public class LoopRegionVisitor extends AbstractVisitor implements IRegionVisitor } private static boolean fixIterableType(InsnArg iterableArg, RegisterArg iterVar) { - ArgType type = iterableArg.getType(); - if (type.isGeneric()) { - ArgType[] genericTypes = type.getGenericTypes(); - if (genericTypes != null && genericTypes.length == 1) { - ArgType gType = genericTypes[0]; - if (ArgType.isInstanceOf(gType, iterVar.getType())) { - return true; - } else { - LOG.warn("Generic type differs: {} and {}", type, iterVar.getType()); - } + ArgType iterableType = iterableArg.getType(); + ArgType varType = iterVar.getType(); + if (iterableType.isGeneric()) { + ArgType[] genericTypes = iterableType.getGenericTypes(); + if (genericTypes == null || genericTypes.length != 1) { + return false; } - } else { - if (!iterableArg.isRegister()) { + ArgType gType = genericTypes[0]; + if (gType.equals(varType)) { return true; } - // TODO: add checks - type = ArgType.generic(type.getObject(), new ArgType[]{iterVar.getType()}); - iterableArg.setType(type); + if (gType.isGenericType()) { + iterVar.setType(gType); + return true; + } + if (ArgType.isInstanceOf(gType, varType)) { + return true; + } + LOG.warn("Generic type differs: {} and {}", gType, varType); + return false; + } + if (!iterableArg.isRegister()) { return true; } - return false; + // TODO: add checks + iterableType = ArgType.generic(iterableType.getObject(), new ArgType[]{varType}); + iterableArg.setType(iterableType); + return true; } /** diff --git a/jadx-core/src/test/java/jadx/tests/integration/loops/TestIterableForEach3.java b/jadx-core/src/test/java/jadx/tests/integration/loops/TestIterableForEach3.java new file mode 100644 index 000000000..e84223493 --- /dev/null +++ b/jadx-core/src/test/java/jadx/tests/integration/loops/TestIterableForEach3.java @@ -0,0 +1,43 @@ +package jadx.tests.integration.loops; + +import jadx.core.dex.nodes.ClassNode; +import jadx.tests.api.IntegrationTest; + +import java.util.Set; + +import org.junit.Test; + +import static jadx.tests.api.utils.JadxMatchers.containsOne; +import static org.junit.Assert.assertThat; + +public class TestIterableForEach3 extends IntegrationTest { + + public static class TestCls { + private Set a; + private Set b; + + private void test(T str) { + Set set = str.length() == 1 ? a : b; + for (T s : set) { + if (s.length() == str.length()) { + if (str.length() == 0) { + set.remove(s); + } else { + set.add(str); + } + return; + } + } + } + } + + @Test + public void test() { + ClassNode cls = getClassNode(TestCls.class); + String code = cls.getCode().toString(); + + assertThat(code, containsOne("for (T s : set) {")); + assertThat(code, containsOne("if (str.length() == 0) {")); + // TODO move return outside 'if' + } +}