fix: adjust types for arithmetic instructions (#921)

This commit is contained in:
Skylot
2020-09-11 19:29:55 +01:00
parent 50cfa4c971
commit 60b2353afe
9 changed files with 178 additions and 123 deletions
@@ -1,5 +1,7 @@
package jadx.core.dex.instructions;
import org.jetbrains.annotations.Nullable;
import jadx.api.plugins.input.insns.InsnData;
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.args.ArgType;
@@ -8,41 +10,54 @@ import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
import jadx.core.dex.nodes.InsnNode;
import jadx.core.utils.InsnUtils;
import jadx.core.utils.exceptions.JadxRuntimeException;
public class ArithNode extends InsnNode {
private final ArithOp op;
public ArithNode(InsnData insn, ArithOp op, ArgType type, boolean literal) {
super(InsnType.ARITH, 2);
this.op = op;
setResult(InsnArg.reg(insn, 0, type));
int rc = insn.getRegsCount();
if (literal) {
if (rc == 1) {
// self
addReg(insn, 0, type);
addLit(insn, type);
} else if (rc == 2) {
// normal
addReg(insn, 1, type);
addLit(insn, type);
}
} else {
if (rc == 2) {
// self
addReg(insn, 0, type);
addReg(insn, 1, type);
} else if (rc == 3) {
// normal
addReg(insn, 1, type);
addReg(insn, 2, type);
}
public static ArithNode build(InsnData insn, ArithOp op, ArgType type) {
RegisterArg resArg = InsnArg.reg(insn, 0, fixResultType(op, type));
ArgType argType = fixArgType(op, type);
switch (insn.getRegsCount()) {
case 2:
return new ArithNode(op, resArg, InsnArg.reg(insn, 0, argType), InsnArg.reg(insn, 1, argType));
case 3:
return new ArithNode(op, resArg, InsnArg.reg(insn, 1, argType), InsnArg.reg(insn, 2, argType));
default:
throw new JadxRuntimeException("Unexpected registers count in " + insn);
}
}
public ArithNode(ArithOp op, RegisterArg res, InsnArg a, InsnArg b) {
public static ArithNode buildLit(InsnData insn, ArithOp op, ArgType type) {
RegisterArg resArg = InsnArg.reg(insn, 0, fixResultType(op, type));
ArgType argType = fixArgType(op, type);
LiteralArg litArg = InsnArg.lit(insn, argType);
switch (insn.getRegsCount()) {
case 1:
return new ArithNode(op, resArg, InsnArg.reg(insn, 0, argType), litArg);
case 2:
return new ArithNode(op, resArg, InsnArg.reg(insn, 1, argType), litArg);
default:
throw new JadxRuntimeException("Unexpected registers count in " + insn);
}
}
private static ArgType fixResultType(ArithOp op, ArgType type) {
if (type == ArgType.INT && op.isBitOp()) {
return ArgType.INT_BOOLEAN;
}
return type;
}
private static ArgType fixArgType(ArithOp op, ArgType type) {
if (type == ArgType.INT && op.isBitOp()) {
return ArgType.NARROW_NUMBERS_NO_FLOAT;
}
return type;
}
private final ArithOp op;
public ArithNode(ArithOp op, @Nullable RegisterArg res, InsnArg a, InsnArg b) {
super(InsnType.ARITH, 2);
this.op = op;
setResult(res);
@@ -50,10 +65,6 @@ public class ArithNode extends InsnNode {
addArg(b);
}
public ArithNode(ArithOp op, InsnArg a, InsnArg b) {
this(op, null, a, b);
}
/**
* Create one argument arithmetic instructions (a+=2).
* Result is not set (null).
@@ -61,7 +72,7 @@ public class ArithNode extends InsnNode {
* @param res argument to change
*/
public static ArithNode oneArgOp(ArithOp op, InsnArg res, InsnArg a) {
ArithNode insn = new ArithNode(op, res, a);
ArithNode insn = new ArithNode(op, null, res, a);
insn.add(AFlag.ARITH_ONEARG);
return insn;
}
@@ -100,7 +111,7 @@ public class ArithNode extends InsnNode {
@Override
public InsnNode copy() {
ArithNode copy = new ArithNode(op, getArg(0).duplicate(), getArg(1).duplicate());
ArithNode copy = new ArithNode(op, null, getArg(0).duplicate(), getArg(1).duplicate());
return copyCommonParams(copy);
}
@@ -24,4 +24,15 @@ public enum ArithOp {
public String getSymbol() {
return this.symbol;
}
public boolean isBitOp() {
switch (this) {
case AND:
case OR:
case XOR:
return true;
default:
return false;
}
}
}
@@ -542,19 +542,11 @@ public class InsnDecoder {
}
private InsnNode arith(InsnData insn, ArithOp op, ArgType type) {
return new ArithNode(insn, op, fixTypeForBitOps(op, type), false);
return ArithNode.build(insn, op, type);
}
private InsnNode arithLit(InsnData insn, ArithOp op, ArgType type) {
return new ArithNode(insn, op, fixTypeForBitOps(op, type), true);
}
private ArgType fixTypeForBitOps(ArithOp op, ArgType type) {
if (type == ArgType.INT
&& (op == ArithOp.AND || op == ArithOp.OR || op == ArithOp.XOR)) {
return ArgType.NARROW_NUMBERS_NO_FLOAT;
}
return type;
return ArithNode.buildLit(insn, op, type);
}
private InsnNode neg(InsnData insn, ArgType type) {
@@ -65,6 +65,7 @@ public abstract class ArgType {
public static final ArgType WIDE = unknown(PrimitiveType.LONG, PrimitiveType.DOUBLE);
public static final ArgType INT_FLOAT = unknown(PrimitiveType.INT, PrimitiveType.FLOAT);
public static final ArgType INT_BOOLEAN = unknown(PrimitiveType.INT, PrimitiveType.BOOLEAN);
protected int hash;
@@ -14,6 +14,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import jadx.core.Consts;
import jadx.core.dex.instructions.ArithNode;
import jadx.core.dex.instructions.BaseInvokeNode;
import jadx.core.dex.instructions.IndexInsnNode;
import jadx.core.dex.instructions.InsnType;
@@ -286,7 +287,7 @@ public final class TypeUpdate {
registry.put(InsnType.AGET, this::arrayGetListener);
registry.put(InsnType.APUT, this::arrayPutListener);
registry.put(InsnType.IF, this::ifListener);
registry.put(InsnType.ARITH, this::suggestAllSameListener);
registry.put(InsnType.ARITH, this::arithListener);
registry.put(InsnType.NEG, this::suggestAllSameListener);
registry.put(InsnType.NOT, this::suggestAllSameListener);
registry.put(InsnType.CHECK_CAST, this::checkCastListener);
@@ -441,12 +442,24 @@ public final class TypeUpdate {
return allSame ? SAME : CHANGED;
}
private TypeUpdateResult arithListener(TypeUpdateInfo updateInfo, InsnNode insn, InsnArg arg, ArgType candidateType) {
ArithNode arithInsn = (ArithNode) insn;
if (candidateType == ArgType.BOOLEAN && arithInsn.getOp().isBitOp()) {
// force all args to boolean
return allSameListener(updateInfo, insn, arg, candidateType);
}
return suggestAllSameListener(updateInfo, insn, arg, candidateType);
}
/**
* Try to set candidate type to all args, don't fail on reject
*/
private TypeUpdateResult suggestAllSameListener(TypeUpdateInfo updateInfo, InsnNode insn, InsnArg arg, ArgType candidateType) {
if (!isAssign(insn, arg)) {
updateTypeChecked(updateInfo, insn.getResult(), candidateType);
RegisterArg resultArg = insn.getResult();
if (resultArg != null) {
updateTypeChecked(updateInfo, resultArg, candidateType);
}
}
boolean allSame = true;
for (InsnArg insnArg : insn.getArguments()) {
@@ -0,0 +1,37 @@
package jadx.tests.integration.arith;
import org.junit.jupiter.api.Test;
import jadx.tests.api.IntegrationTest;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestArith4 extends IntegrationTest {
public static class TestCls {
public static byte test(byte b) {
int k = b & 7;
return (byte) (((b & 255) >>> (8 - k)) | (b << k));
}
public static int test2(String str) {
int k = 'a' | str.charAt(0);
return (1 - k) & (1 + k);
}
}
@Test
public void test() {
assertThat(getClassNode(TestCls.class))
.code()
.containsOne("int k = b & 7;")
.containsOne("return (1 - k) & (k + 1);");
}
@Test
public void testNoDebug() {
noDebugInfo();
assertThat(getClassNode(TestCls.class))
.code();
}
}
@@ -0,0 +1,63 @@
package jadx.tests.integration.arith;
import org.junit.jupiter.api.Test;
import jadx.tests.api.SmaliTest;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestXor extends SmaliTest {
@SuppressWarnings("PointlessBooleanExpression")
public static class TestCls {
public boolean test1() {
return test() ^ true;
}
public boolean test2(boolean v) {
return v ^ true;
}
public boolean test() {
return true;
}
public void check() {
assertThat(test1()).isFalse();
assertThat(test2(true)).isFalse();
assertThat(test2(false)).isTrue();
}
}
@Test
public void test() {
assertThat(getClassNode(TestCls.class))
.code()
.containsOne("return !test();")
.containsOne("return !v;");
}
@Test
public void smali() {
// @formatter:off
/*
public boolean test1() {
return test() ^ true;
}
public boolean test2() {
return test() ^ false;
}
public boolean test() {
return true;
}
*/
// @formatter:on
assertThat(getClassNodeFromSmali())
.code()
.containsOne("return !test();")
.containsOne("return test();");
}
}
@@ -1,61 +0,0 @@
package jadx.tests.integration.conditions;
import org.junit.jupiter.api.Test;
import jadx.core.dex.nodes.ClassNode;
import jadx.tests.api.SmaliTest;
import static jadx.tests.api.utils.JadxMatchers.containsOne;
import static org.hamcrest.MatcherAssert.assertThat;
public class TestXor extends SmaliTest {
public static class TestCls {
public boolean test1() {
return test() ^ true;
}
public boolean test2(boolean v) {
return v ^ true;
}
public boolean test() {
return true;
}
}
@Test
public void test() {
ClassNode cls = getClassNode(TestCls.class);
String code = cls.getCode().toString();
assertThat(code, containsOne("return !test();"));
assertThat(code, containsOne("return !v;"));
}
@Test
public void smali() {
// @formatter:off
/*
public boolean test1() {
return test() ^ true;
}
public boolean test2() {
return test() ^ false;
}
public boolean test() {
return true;
}
*/
// @formatter:on
ClassNode cls = getClassNodeFromSmaliWithPath("conditions", "TestXor");
String code = cls.getCode().toString();
assertThat(code, containsOne("return !test();"));
assertThat(code, containsOne("return test();"));
}
}
@@ -1,19 +1,7 @@
.class public LTestXor;
.class public Larith/TestXor;
.super Ljava/lang/Object;
# direct methods
.method public constructor <init>()V
.locals 0
.line 9
invoke-direct {p0}, Ljava/lang/Object;-><init>()V
return-void
.end method
# virtual methods
.method public test()Z
.locals 1
@@ -27,7 +15,7 @@
.locals 1
.line 12
invoke-virtual {p0}, Lcom/example/myapplication/MainActivity;->test()Z
invoke-virtual {p0}, Larith/TestXor;->test()Z
move-result v0
@@ -40,7 +28,7 @@
.locals 1
.line 16
invoke-virtual {p0}, Lcom/example/myapplication/MainActivity;->test()Z
invoke-virtual {p0}, Larith/TestXor;->test()Z
move-result v0