fix: use correct args copy/replace in wrapped insns (#2835)

This commit is contained in:
Skylot
2026-04-01 20:19:56 +01:00
parent 9a8a11619b
commit 325b3ac991
10 changed files with 144 additions and 37 deletions
@@ -35,6 +35,6 @@ public final class ConstStringNode extends InsnNode {
@Override
public String toString() {
return super.toString() + ' ' + StringUtils.getInstance().unescapeString(str);
return super.baseString() + StringUtils.getInstance().unescapeString(str) + super.attributesString();
}
}
@@ -84,7 +84,7 @@ public final class InsnWrapArg extends InsnArg {
if (wrappedInsn.getType() == InsnType.CONST_STR) {
return "(\"" + ((ConstStringNode) wrappedInsn).getString() + "\")";
}
return "(wrap:" + type + ":" + wrappedInsn.getType() + ')';
return "(wrap " + type + ":" + wrappedInsn.getType() + ')';
}
@Override
@@ -92,6 +92,6 @@ public final class InsnWrapArg extends InsnArg {
if (wrappedInsn.getType() == InsnType.CONST_STR) {
return "(\"" + ((ConstStringNode) wrappedInsn).getString() + "\")";
}
return "(wrap:" + type + ":" + wrappedInsn + ')';
return "(wrap " + type + ":" + wrappedInsn + ')';
}
}
@@ -406,6 +406,14 @@ public class InsnNode extends LineAttrNode {
&& Objects.equals(arguments, other.arguments);
}
@SuppressWarnings("unchecked")
public static <T extends InsnArg> @Nullable T duplicateArg(@Nullable T arg) {
if (arg == null) {
return null;
}
return (T) arg.duplicate();
}
protected final <T extends InsnNode> T copyCommonParams(T copy) {
if (copy.getArgsCount() == 0) {
for (InsnArg arg : this.getArguments()) {
@@ -85,7 +85,6 @@ public class SimplifyVisitor extends AbstractVisitor {
int insnCount = list.size();
InsnNode modInsn = simplifyInsn(mth, insn, null);
if (modInsn != null) {
modInsn.rebindArgs();
if (i < list.size() && list.get(i) == insn) {
list.set(i, modInsn);
} else {
@@ -95,6 +94,8 @@ public class SimplifyVisitor extends AbstractVisitor {
}
list.set(idx, modInsn);
}
InsnRemover.unbindInsn(mth, insn);
modInsn.rebindArgs();
if (list.size() < insnCount) {
// some insns removed => restart block processing
simplifyBlock(mth, block);
@@ -239,8 +240,8 @@ public class SimplifyVisitor extends AbstractVisitor {
|| shadowedByOuterCast(mth.root(), castToType, parentInsn)) {
InsnNode insnNode = new InsnNode(InsnType.MOVE, 1);
insnNode.setOffset(castInsn.getOffset());
insnNode.setResult(castInsn.getResult());
insnNode.addArg(castArg);
insnNode.setResult(InsnNode.duplicateArg(castInsn.getResult()));
insnNode.addArg(castArg.duplicate());
return insnNode;
}
return null;
@@ -576,7 +577,11 @@ public class SimplifyVisitor extends AbstractVisitor {
if (litArg.isNegative()) {
LiteralArg negLitArg = litArg.negate();
if (negLitArg != null) {
return new ArithNode(ArithOp.SUB, arith.getResult(), arith.getArg(0), negLitArg);
RegisterArg resArg = InsnNode.duplicateArg(arith.getResult());
ArithNode newInsn = new ArithNode(ArithOp.SUB, resArg, arith.getArg(0).duplicate(), negLitArg);
newInsn.copyAttributesFrom(arith);
newInsn.setOffset(arith.getOffset());
return newInsn;
}
}
break;
@@ -586,10 +591,12 @@ public class SimplifyVisitor extends AbstractVisitor {
InsnArg firstArg = arith.getArg(0);
long lit = litArg.getLiteral();
if (firstArg.getType() == ArgType.BOOLEAN && (lit == 0 || lit == 1)) {
InsnNode node = new InsnNode(lit == 0 ? InsnType.MOVE : InsnType.NOT, 1);
node.setResult(arith.getResult());
node.addArg(firstArg);
return node;
InsnNode newInsn = new InsnNode(lit == 0 ? InsnType.MOVE : InsnType.NOT, 1);
newInsn.setResult(InsnNode.duplicateArg(arith.getResult()));
newInsn.addArg(firstArg.duplicate());
newInsn.copyAttributesFrom(arith);
newInsn.setOffset(arith.getOffset());
return newInsn;
}
break;
}
@@ -637,16 +644,22 @@ public class SimplifyVisitor extends AbstractVisitor {
}
if (wrapType == InsnType.ARITH) {
ArithNode ar = (ArithNode) wrap;
return ArithNode.oneArgOp(ar.getOp(), fArg, ar.getArg(1));
ArithNode newInsn = ArithNode.oneArgOp(ar.getOp(), fArg, ar.getArg(1).duplicate());
newInsn.copyAttributesFrom(insn);
newInsn.setOffset(insn.getOffset());
return newInsn;
}
int argsCount = wrap.getArgsCount();
InsnNode concat = new InsnNode(InsnType.STR_CONCAT, argsCount - 1);
for (int i = 1; i < argsCount; i++) {
concat.addArg(wrap.getArg(i));
concat.addArg(wrap.getArg(i).duplicate());
}
InsnArg concatArg = InsnArg.wrapArg(concat);
concatArg.setType(ArgType.STRING);
return ArithNode.oneArgOp(ArithOp.ADD, fArg, concatArg);
ArithNode newInsn = ArithNode.oneArgOp(ArithOp.ADD, fArg, concatArg);
newInsn.copyAttributesFrom(wrap);
newInsn.setOffset(wrap.getOffset());
return newInsn;
} catch (Exception e) {
LOG.debug("Can't convert field arith insn: {}, mth: {}", insn, mth, e);
}
@@ -111,23 +111,26 @@ public class TernaryMod extends AbstractRegionVisitor implements IRegionIterativ
RegisterArg resArg;
if (thenPhi.getArgsCount() == 2) {
resArg = thenPhi.getResult();
InsnRemover.unbindResult(mth, thenInsn);
} else {
resArg = thenResArg;
thenPhi.removeArg(elseResArg);
}
InsnArg thenArg = InsnArg.wrapInsnIntoArg(thenInsn);
InsnArg elseArg = InsnArg.wrapInsnIntoArg(elseInsn);
TernaryInsn ternInsn = new TernaryInsn(ifRegion.getCondition(), resArg, thenArg, elseArg);
InsnArg thenArg = InsnArg.wrapInsnIntoArg(thenInsn.copyWithoutResult());
InsnArg elseArg = InsnArg.wrapInsnIntoArg(elseInsn.copyWithoutResult());
TernaryInsn ternInsn = new TernaryInsn(ifRegion.getCondition(), resArg.duplicate(), thenArg, elseArg);
int branchLine = Math.max(thenInsn.getSourceLine(), elseInsn.getSourceLine());
ternInsn.setSourceLine(Math.max(ifRegion.getSourceLine(), branchLine));
thenInsn.setResult(null); // unset without unbind, SSA var still in use
InsnRemover.unbindResult(mth, elseInsn);
InsnRemover.unbindInsn(mth, thenInsn);
InsnRemover.unbindInsn(mth, elseInsn);
ternInsn.rebindArgs();
if (thenPhi.getArgsCount() == 0) {
InsnRemover.unbindResult(mth, thenPhi);
InsnRemover.delistPhi(mth, thenPhi);
}
// remove 'if' instruction
header.getInstructions().clear();
ternInsn.rebindArgs();
header.getInstructions().add(ternInsn);
clearConditionBlocks(conditionBlocks, header);
@@ -321,11 +324,11 @@ public class TernaryMod extends AbstractRegionVisitor implements IRegionIterativ
InsnArg elseArg;
if (elseAssign != null && elseAssign.isConstInsn()) {
// inline constant
elseArg = InsnArg.wrapInsnIntoArg(elseAssign.copyWithoutResult());
SSAVar elseVar = elseAssign.getResult().getSVar();
if (elseVar.getUseCount() == 1 && elseVar.getOnlyOneUseInPhi() == phiInsn) {
InsnRemover.remove(mth, elseAssign);
}
elseArg = InsnArg.wrapInsnIntoArg(elseAssign);
} else {
elseArg = otherArg.duplicate();
}
@@ -78,8 +78,13 @@ public class DebugChecks {
for (InsnArg arg : insn.getArguments()) {
if (arg instanceof RegisterArg) {
checkVar(mth, insn, (RegisterArg) arg);
} else if (arg.isInsnWrap()) {
} else if (arg instanceof InsnWrapArg) {
InsnNode wrapInsn = ((InsnWrapArg) arg).getWrapInsn();
if (wrapInsn.contains(AFlag.DONT_GENERATE)
&& !insn.contains(AFlag.DONT_GENERATE)
&& !mth.contains(AFlag.DONT_GENERATE)) {
throw new JadxRuntimeException("Not generated wrapped insn: \n " + wrapInsn + ",\nouter insn:\n " + insn);
}
checkInsn(mth, block, wrapInsn);
}
}
@@ -3,10 +3,12 @@ package jadx.tests.integration.arith;
import org.junit.jupiter.api.Test;
import jadx.tests.api.IntegrationTest;
import jadx.tests.api.utils.assertj.JadxAssertions;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestFieldIncrement extends IntegrationTest {
@SuppressWarnings("unused")
public static class TestCls {
public int instanceField = 1;
public static int staticField = 1;
@@ -27,7 +29,7 @@ public class TestFieldIncrement extends IntegrationTest {
@Test
public void test() {
JadxAssertions.assertThat(getClassNode(TestCls.class))
assertThat(getClassNode(TestCls.class))
.code()
.contains("instanceField++;")
.contains("staticField--;")
@@ -5,10 +5,12 @@ import java.util.Random;
import org.junit.jupiter.api.Test;
import jadx.tests.api.IntegrationTest;
import jadx.tests.api.utils.assertj.JadxAssertions;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestFieldIncrement3 extends IntegrationTest {
@SuppressWarnings("SpellCheckingInspection")
public static class TestCls {
static int tileX;
static int tileY;
@@ -21,20 +23,20 @@ public class TestFieldIncrement3 extends IntegrationTest {
int direction = rd.nextInt(7);
switch (direction) {
case 0:
targetPos.x = ((tileX + 1) * 55) + 55;
targetPos.y = ((tileY + 1) * 35) + 35;
targetPos.x = (tileX + 1) * 55 + 55;
targetPos.y = (tileY + 1) * 35 + 35;
break;
case 2:
targetPos.x = ((tileX + 1) * 55) + 55;
targetPos.y = ((tileY - 1) * 35) + 35;
targetPos.x = (tileX + 1) * 55 + 55;
targetPos.y = (tileY - 1) * 35 + 35;
break;
case 4:
targetPos.x = ((tileX - 1) * 55) + 55;
targetPos.y = ((tileY - 1) * 35) + 35;
targetPos.x = (tileX - 1) * 55 + 55;
targetPos.y = (tileY - 1) * 35 + 35;
break;
case 6:
targetPos.x = ((tileX - 1) * 55) + 55;
targetPos.y = ((tileY + 1) * 35) + 35;
targetPos.x = (tileX - 1) * 55 + 55;
targetPos.y = (tileY + 1) * 35 + 35;
break;
default:
break;
@@ -42,7 +44,7 @@ public class TestFieldIncrement3 extends IntegrationTest {
directVect.x = targetPos.x - newPos.x;
directVect.y = targetPos.y - newPos.y;
float hPos = (float) Math.sqrt((directVect.x * directVect.x) + (directVect.y * directVect.y));
float hPos = (float) Math.sqrt(directVect.x * directVect.x + directVect.y * directVect.y);
directVect.x /= hPos;
directVect.y /= hPos;
}
@@ -57,14 +59,14 @@ public class TestFieldIncrement3 extends IntegrationTest {
}
public boolean equals(Vector2 other) {
return (this.x == other.x && this.y == other.y);
return this.x == other.x && this.y == other.y;
}
}
}
@Test
public void test() {
JadxAssertions.assertThat(getClassNode(TestCls.class))
assertThat(getClassNode(TestCls.class))
.code()
.contains("directVect.x = targetPos.x - newPos.x;");
}
@@ -0,0 +1,17 @@
package jadx.tests.integration.variables;
import org.junit.jupiter.api.Test;
import jadx.tests.api.SmaliTest;
import static jadx.tests.api.utils.assertj.JadxAssertions.assertThat;
public class TestVariables7 extends SmaliTest {
@Test
public void testNoDebug() {
getArgs().setDebugInfo(false);
assertThat(getClassNodeFromSmali())
.code()
.doesNotContain("r0");
}
}
@@ -0,0 +1,57 @@
.class Lvariables/TestVariables7;
.super Ljava/lang/Object;
.method test([BII)I
.registers 12
.param p1, "a" # [B
.param p2, "b" # I
.param p3, "c" # I
.prologue
.line 290
invoke-static {p3}, Ljava/lang/Integer;->toOctalString(I)Ljava/lang/String;
move-result-object v3
.line 291
.local v3, "oct":Ljava/lang/String;
invoke-virtual {v3}, Ljava/lang/String;->length()I
move-result v2
.line 292
.local v2, "len":I
sub-int v4, p2, v2
.line 293
.local v4, "off":I
const/4 v5, 0x0
.line 294
.local v5, "sum":I
const/4 v1, 0x0
.local v1, "j":I
:goto_c
if-lt v1, v2, :cond_f
.line 299
return v5
.line 295
:cond_f
invoke-virtual {v3, v1}, Ljava/lang/String;->charAt(I)C
move-result v0
.line 296
.local v0, "ch":C
and-int/lit16 v6, v0, 0xff
add-int/lit8 v6, v6, -0x30
add-int/2addr v5, v6
.line 297
add-int v6, v4, v1
int-to-byte v7, v0
aput-byte v7, p1, v6
.line 294
add-int/lit8 v1, v1, 0x1
goto :goto_c
.end method