From ed320233c3b6de2c79206d1a270b822036abd6ea Mon Sep 17 00:00:00 2001
From: Anh Nguyen-Tuong <zenpoems@gmail.com>
Date: Mon, 18 Mar 2019 17:37:17 -0400
Subject: [PATCH] Working laf-intel + dom graph combo

---
 tools/laf/laf.cpp                         | 112 +++++++++++++++-------
 tools/laf/test/{test5.c => test_cmp_32.c} |  15 ++-
 tools/laf/test/test_cmp_64.c              |  30 ++++++
 tools/zax/zax_base.cpp                    |   3 +-
 zipr_umbrella                             |   2 +-
 5 files changed, 124 insertions(+), 38 deletions(-)
 rename tools/laf/test/{test5.c => test_cmp_32.c} (63%)
 create mode 100644 tools/laf/test/test_cmp_64.c

diff --git a/tools/laf/laf.cpp b/tools/laf/laf.cpp
index 2327cdf..f748b8b 100644
--- a/tools/laf/laf.cpp
+++ b/tools/laf/laf.cpp
@@ -255,20 +255,35 @@ int Laf_t::doTraceCompare()
 			{
 				if (traceBytes2(c, d.getImmediate()))
 				{
-					if (m_verbose) cout << "success for " << s << endl;
+					if (m_verbose) 
+					{
+						cout << "success for " << s << endl;
+					}
 					m_num_cmp_instrumented++;
 				}
 			}
 			else if (traceBytes48(c, d.getOperand(0)->getArgumentSizeInBytes(), d.getImmediate()))
 			{
-				if (m_verbose) cout << "success for " << s << endl;
+				if (m_verbose)
+				{
+					cout << "success for " << s << endl;
+				}
 				m_num_cmp_instrumented++;
 			}
 		}
 
+		if (m_verbose) 
+		{
+ 			getFileIR()->assembleRegistry();
+		 	getFileIR()->setBaseIDS();
+			cout << "Post transformation CFG for " << func->getName() << ":" << endl;
+			auto post_cfg=ControlFlowGraph_t::factory(func);	
+			cout << *post_cfg << endl;
+		}
 	};
 
 	return 1;	 // true means success
+
 }
 
 int Laf_t::doTraceDiv()
@@ -445,24 +460,27 @@ bool Laf_t::traceBytes2(Instruction_t *p_instr, const uint32_t p_immediate)
 // p_reg is a free register
 // p_num_bytes has value 4 or 8
 // 
-//     t = reg[lodword]           ; t is a register
-//     m = k[lodword]             ; m is memory where we stashed the constant
+//     t = reg[0..3]          ; t is a register
+//     m = k[0..3]            ; m is memory where we stashed the constant
 //     cmp t, dword [m]          ; elide if 4 byte compare
 //  +- je check_upper            ; elide if 4 byte compare
-//  |  cmp t, dword [m]  <--+    ; loop_back
-//  |  je orig              |
-//  |  t >> 8               |
-//  |  m >> 8               |
-//  |  jmp -----------------+
-//  |
+//  |  
+//  |  cmp t, byte [m]   <---+    ; loop_back
+//  |  jne orig              |
+//  |  t >> 8                |
+//  |  m >> 8                |
+//  |  jmp ------------------+
+//  |  and t, 0x00ffffff        ; clear 4th byte of t (emitted before loop_back)
+//  |  mov m[3], 0xff           ; set 4th byte of m to 0xff so the loop terminates
 //  check_upper:                ; only if 8 byte compare
-//     t1 = reg[hidword]
-//     m = k[hidword]
-//     cmp t, dword [m]  <--+
-//     je orig              |
-//     t >> 8               |
-//     m >> 8               |
-//     jmp -----------------+
+//     t = reg[4..7]
+//     and t, 0x00ffffff        ; clear top byte of t (byte 7 of reg)
+//     mov m[7], 0xff           ; set byte 7 of m to 0xff so the loop terminates
+//     cmp t, dword [m+4] <--+
+//     jne orig              |
+//     t >> 8                |
+//     m >> 8                |
+//     jmp ------------------+
 //
 // orig:
 //     cmp reg, K or cmp [], K
@@ -518,6 +536,7 @@ bool Laf_t::traceBytes48(Instruction_t *p_instr, size_t p_num_bytes, uint64_t p_
 		save_tmp = getFreeRegister(p_instr, free_reg8, RegisterSet_t({rn_RBX, rn_RCX, rn_RDI, rn_RSI, rn_R8, rn_R9, rn_R10, rn_R11, rn_R12, rn_R13, rn_R14, rn_R15}));
 	else
 		save_tmp = getFreeRegister(p_instr, free_reg8, RegisterSet_t({rn_RAX, rn_RBX, rn_RCX, rn_RDX, rn_RDI, rn_RSI, rn_R8, rn_R9, rn_R10, rn_R11, rn_R12, rn_R13, rn_R14, rn_R15}));
+	const auto free_reg1 = registerToString(convertRegisterTo8bit(Register::getRegister(free_reg8)));
 	const auto free_reg4 = registerToString(convertRegisterTo32bit(Register::getRegister(free_reg8)));
 	if(free_reg8.empty()) throw;
 	
@@ -536,6 +555,7 @@ bool Laf_t::traceBytes48(Instruction_t *p_instr, size_t p_num_bytes, uint64_t p_
 		cout << "save tmp: " << s << endl;
 	}
 
+	// copy value into free register
 	if (d.getOperand(0)->isRegister())
 	{
 		auto source_reg = d.getOperand(0)->getString();
@@ -566,12 +586,24 @@ bool Laf_t::traceBytes48(Instruction_t *p_instr, size_t p_num_bytes, uint64_t p_
 		cout << s << endl;
 	}
 
+	// make sure to terminate by using 0x00 and 0xFF in upper byte
+	//
+	// clear 4th byte of value
+	s = "and " + free_reg4 + ", 0x00FFFFFF";
+	t = insertAssemblyAfter(t, s);
+	cout << "clear byte4 of val: " << s << endl;
+	
+	// set 4th byte of constant
+	s = "mov byte ["  + mem + "+3], 0xFF";
+	t = insertAssemblyAfter(t, s);
+	cout << "clear byte4 of K  : " << s << endl;
+
 	// loop_back
-	s = "cmp " + free_reg4 + ", dword [" + mem + "]";
+	s = "cmp " + free_reg1 + ", byte [" + mem + "]";
 	const auto loop_back = t = insertAssemblyAfter(t, s);
 	cout << s << endl;
 
-	s = "je 0"; // orig
+	s = "jne 0"; // orig
 	t = insertAssemblyAfter(t, s);
 	t->setTarget(traced_instr);
 	cout << s << endl;
@@ -584,8 +616,8 @@ bool Laf_t::traceBytes48(Instruction_t *p_instr, size_t p_num_bytes, uint64_t p_
 	t = insertAssemblyAfter(t, s);
 	cout << s << endl;
 
-	s = "jmp 0"; // loop_back
-	t = insertAssemblyAfter(t, "jmp 0"); // loop_back
+	s = "jmp 0"; // jump to loop_back
+	t = insertAssemblyAfter(t, s); 
 	t->setTarget(loop_back);
 	t->setFallthrough(nullptr);
 	cout << s << endl;
@@ -606,16 +638,17 @@ bool Laf_t::traceBytes48(Instruction_t *p_instr, size_t p_num_bytes, uint64_t p_
 	//
 	// mem is memory location representing constant (hidword)
 	// p_reg is register representing the value to be checked (hidword)
+	// 
+	//  check_upper:                ; only if 8 byte compare
+	//     t = reg[4..7]
+	//     and t, 0x00ffffff        ; clear top byte of t (byte 7 of reg)
+	//     mov m[7], 0xff           ; set byte 7 of m to 0xff so the loop terminates
+	//     cmp t, dword [m+4] <--+
+	//     jne orig              |
+	//     t >> 8                |
+	//     m >> 8                |
+	//     jmp ------------------+
 	//
-//  check_upper: 
-//     t1 = reg[hidword]
-//     m = k[hidword]
-//     cmp t, dword [m]  <--+
-//     je orig              |
-//     t >> 8               |
-//     m >> 8               |
-//     jmp -----------------+
-//
 
 	if (d.getOperand(0)->isRegister())
 	{
@@ -641,12 +674,24 @@ bool Laf_t::traceBytes48(Instruction_t *p_instr, size_t p_num_bytes, uint64_t p_
 		cout << s << endl;
 	}
 
+	// make sure to terminate by using 0x00 and 0xFF in upper byte
+	//
+	// clear 4th byte of value (7th byte of original)
+	s = "and " + free_reg4 + ", 0x00FFFFFF";
+	t = insertAssemblyAfter(t, s);
+	cout << "clear byte7 of orig val: " << s << endl;
+	
+	// set 7th byte of constant
+	s = "mov byte ["  + mem + "+7], 0xFF";
+	t = insertAssemblyAfter(t, s);
+	cout << "clear byte7 of K  : " << s << endl;
 
-	s = "cmp " + free_reg4 + ", dword [" + mem + "+4]"; // loop_back2
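+	// loop_back2 -- NOTE(review): still a dword compare on free_reg4, unlike the byte compare on free_reg1 at loop_back; confirm intended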
+	s = "cmp " + free_reg4 + ", dword [" + mem + "+4]"; 
 	auto loop_back2 = t = insertAssemblyAfter(t, s);
 	cout << s << endl;
 
-	s = "je 0"; // orig
+	s = "jne 0"; // orig
 	t = insertAssemblyAfter(t, s);
 	t->setTarget(traced_instr);
 	cout << s << endl;
@@ -659,9 +704,10 @@ bool Laf_t::traceBytes48(Instruction_t *p_instr, size_t p_num_bytes, uint64_t p_
 	t = insertAssemblyAfter(t, s);
 	cout << s << endl;
 
-	s = "jmp 0"; // loop_back2
+	s = "jmp 0"; // jmp to loop_back2
 	t = insertAssemblyAfter(t, s);
 	t->setTarget(loop_back2);
+	t->setFallthrough(nullptr);
 	cout << s << endl;
 
 	return true;
diff --git a/tools/laf/test/test5.c b/tools/laf/test/test_cmp_32.c
similarity index 63%
rename from tools/laf/test/test5.c
rename to tools/laf/test/test_cmp_32.c
index cb06cc6..bd0a574 100644
--- a/tools/laf/test/test5.c
+++ b/tools/laf/test/test_cmp_32.c
@@ -1,16 +1,25 @@
-#include <iostream>
 #include <stdio.h>
 #include <stdlib.h>
 
+volatile int compare_me(int x)	/* aborts when x equals the 32-bit magic constant, otherwise returns 0 */
+{
+	if (x == 0x12345678) abort();
+	return 0;	/* non-void function: always hand back a defined value */
+}
+
 int main(int argc, char **argv)
 {
 	int x;
 	FILE *fp = fopen(argv[1],"r");
 
+	if (!fp) {
+		fprintf(stderr, "Need input file\n");
+		return 1;
+	}
+
 	fread(&x, 4, 1, fp);
 
-	if (x == 0x12345678)
-            abort();
+	compare_me(x);
 
 	fclose(fp);
 	return 0;
diff --git a/tools/laf/test/test_cmp_64.c b/tools/laf/test/test_cmp_64.c
new file mode 100644
index 0000000..2b3f4e1
--- /dev/null
+++ b/tools/laf/test/test_cmp_64.c
@@ -0,0 +1,30 @@
+#include <stdio.h>
+#include <stdlib.h>
+
+volatile int compare_me(long x)	/* aborts when x equals the negative magic constant, otherwise returns 0 */
+{
+	if (x == -0x12345678L) abort();
+	return 0;	/* non-void function: always hand back a defined value */
+}
+
+int main(int argc, char **argv)	/* reads one long from the file named by argv[1] and passes it to compare_me */
+{
+	long x = 0;	/* zero-initialized so a short read never leaves x indeterminate */
+	FILE *fp = (argc > 1) ? fopen(argv[1],"r") : NULL;	/* a missing path argument is handled like an unopenable file */
+
+	if (!fp) {
+		fprintf(stderr, "Need input file\n");
+		return 1;
+	}
+
+	printf("sizeof(x)=%zu\n", sizeof(x));	/* %zu is the conversion specifier for size_t */
+
+	/* read exactly one long; sizeof(x) keeps the write inside x whatever the width of long is */
+	(void)fread(&x, sizeof(x), 1, fp);
+
+	compare_me(x);
+
+	fclose(fp);
+	return 0;
+
+}
diff --git a/tools/zax/zax_base.cpp b/tools/zax/zax_base.cpp
index 6c84e54..057b5de 100644
--- a/tools/zax/zax_base.cpp
+++ b/tools/zax/zax_base.cpp
@@ -1189,7 +1189,8 @@ int ZaxBase_t::execute()
 		const auto num_blocks_in_func = cfg.getBlocks().size();
 		m_num_bb += num_blocks_in_func;
 
-		if (m_graph_optimize && num_blocks_in_func == 1)
+		// skip single-block functions that are not indirectly called; NOTE(review): m_graph_optimize no longer gates this elision -- confirm intended
+		if (num_blocks_in_func == 1 && !f->getEntryPoint()->getIndirectBranchTargetAddress())
 		{
 			m_num_single_block_function_elided++;
 			m_num_bb_skipped++;
diff --git a/zipr_umbrella b/zipr_umbrella
index 6461f33..de33315 160000
--- a/zipr_umbrella
+++ b/zipr_umbrella
@@ -1 +1 @@
-Subproject commit 6461f337f9097810d7cf14052a83c8376e823fde
+Subproject commit de33315f785f73df05368ae9ebcb65483c73d7d9
-- 
GitLab