From 77da3cd45e814652de3540da25273a396b991072 Mon Sep 17 00:00:00 2001
From: an7s <an7s@git.zephyr-software.com>
Date: Wed, 9 Mar 2016 18:02:16 +0000
Subject: [PATCH] Keep track of STARS SAFE_FUNC in IRDB

Former-commit-id: 1eb2c568c96a19ce73a5ec41e625bcc069ea9e6c
---
 libIRDB/include/core/function.hpp   |  6 ++++-
 libIRDB/src/core/fileir.cpp         | 17 ++++++++++----
 libIRDB/src/core/function.cpp       |  6 +++--
 tools/meds2pdb/meds2pdb.cpp         |  6 +++--
 tools/selective_cfi/scfi_driver.cpp | 16 +++++++++++--
 tools/selective_cfi/scfi_instr.cpp  | 35 ++++++++++++++++++++++++-----
 tools/selective_cfi/scfi_instr.hpp  |  6 ++++-
 7 files changed, 75 insertions(+), 17 deletions(-)

diff --git a/libIRDB/include/core/function.hpp b/libIRDB/include/core/function.hpp
index 49f123844..85ef858c5 100644
--- a/libIRDB/include/core/function.hpp
+++ b/libIRDB/include/core/function.hpp
@@ -28,7 +28,7 @@ class Function_t : public BaseObj_t
 	Function_t() : BaseObj_t(NULL) {}	// create a new function not in the db 
 
 	// create a function that's already in the DB  
-	Function_t(db_id_t id, std::string name, int size, int oa_size, bool use_fp, FuncType_t *, Instruction_t *entry);	
+	Function_t(db_id_t id, std::string name, int size, int oa_size, bool use_fp, bool is_safe, FuncType_t *, Instruction_t *entry);	
 
 	InstructionSet_t& GetInstructions() { return my_insns; }
 
@@ -48,6 +48,9 @@ class Function_t : public BaseObj_t
         bool GetUseFramePointer() const { return use_fp; }
         void SetUseFramePointer(bool useFP) { use_fp = useFP; }
 
+        void SetSafe(bool safe) { is_safe = safe; }
+        bool IsSafe() const { return is_safe; }
+
 	void SetType(FuncType_t *t) { function_type = t; }
 	FuncType_t* GetType() const { return function_type; }
 
@@ -60,6 +63,7 @@ class Function_t : public BaseObj_t
         std::string name;
         int out_args_region_size;
         bool use_fp;
+        bool is_safe;
 	FuncType_t *function_type;
 };
 
diff --git a/libIRDB/src/core/fileir.cpp b/libIRDB/src/core/fileir.cpp
index 8e7c3860f..ded349ff7 100644
--- a/libIRDB/src/core/fileir.cpp
+++ b/libIRDB/src/core/fileir.cpp
@@ -284,7 +284,7 @@ std::map<db_id_t,Function_t*> FileIR_t::ReadFuncsFromDB
 
 	while(!dbintr->IsDone())
 	{
-// function_id | file_id | name | stack_frame_size | out_args_region_size | use_frame_pointer | doip_id
+// function_id | file_id | name | stack_frame_size | out_args_region_size | use_frame_pointer | is_safe | doip_id
 
 		db_id_t fid=atoi(dbintr->GetResultColumn("function_id").c_str());
 		db_id_t entry_point_id=atoi(dbintr->GetResultColumn("entry_point_id").c_str());
@@ -293,21 +293,30 @@ std::map<db_id_t,Function_t*> FileIR_t::ReadFuncsFromDB
 		int oasize=atoi(dbintr->GetResultColumn("out_args_region_size").c_str());
 		db_id_t function_type_id=atoi(dbintr->GetResultColumn("type_id").c_str());
 // postgresql encoding of boolean can be 'true', '1', 'T', 'y'
-                bool useFP=false;
+		bool useFP=false;
+		bool isSafe=false;
 		string useFPString=dbintr->GetResultColumn("use_frame_pointer"); 
+		string isSafeString=dbintr->GetResultColumn("is_safe"); 
 		const char *useFPstr=useFPString.c_str();
-                if (strlen(useFPstr) > 0)
+		const char *isSafestr=isSafeString.c_str();
+		if (strlen(useFPstr) > 0)
 		{
 			if (useFPstr[0] == 't' || useFPstr[0] == 'T' || useFPstr[0] == '1' || useFPstr[0] == 'y' || useFPstr[0] == 'Y')
 				useFP = true;
 		}
 
+		if (strlen(isSafestr) > 0)
+		{
+			if (isSafestr[0] == 't' || isSafestr[0] == 'T' || isSafestr[0] == '1' || isSafestr[0] == 'y' || isSafestr[0] == 'Y')
+				isSafe = true;
+		}
+
 		db_id_t doipid=atoi(dbintr->GetResultColumn("doip_id").c_str());
 
 		FuncType_t* fnType = NULL;
 		if (typesMap.count(function_type_id) > 0)
 			fnType = dynamic_cast<FuncType_t*>(typesMap[function_type_id]);
-		Function_t *newfunc=new Function_t(fid,name,sfsize,oasize,useFP,fnType, NULL); 
+		Function_t *newfunc=new Function_t(fid,name,sfsize,oasize,useFP,isSafe,fnType, NULL); 
 		entry_points[newfunc]=entry_point_id;
 		
 //std::cout<<"Found function "<<name<<"."<<std::endl;
diff --git a/libIRDB/src/core/function.cpp b/libIRDB/src/core/function.cpp
index dcad3fc58..4167c1b21 100644
--- a/libIRDB/src/core/function.cpp
+++ b/libIRDB/src/core/function.cpp
@@ -25,7 +25,7 @@
 using namespace libIRDB;
 using namespace std;
 
-Function_t::Function_t(db_id_t id, std::string myname, int size, int oa_size, bool useFP, FuncType_t *fn_type, Instruction_t* entry)
+Function_t::Function_t(db_id_t id, std::string myname, int size, int oa_size, bool useFP, bool isSafe, FuncType_t *fn_type, Instruction_t* entry)
 	: BaseObj_t(NULL), entry_point(entry)
 {
 	SetBaseID(id);
@@ -33,6 +33,7 @@ Function_t::Function_t(db_id_t id, std::string myname, int size, int oa_size, bo
 	stack_frame_size=size;	
 	out_args_region_size=oa_size;
     use_fp = useFP;
+    SetSafe(isSafe);
 	function_type = fn_type;
 }
 
@@ -52,7 +53,7 @@ string Function_t::WriteToDB(File_t *fid, db_id_t newid)
 		function_type_id = GetType()->GetBaseID();	 
 
 	string q=string("insert into ")+fid->function_table_name + 
-		string(" (function_id, entry_point_id, name, stack_frame_size, out_args_region_size, use_frame_pointer, type_id, doip_id) ")+
+		string(" (function_id, entry_point_id, name, stack_frame_size, out_args_region_size, use_frame_pointer, is_safe, type_id, doip_id) ")+
 		string(" VALUES (") + 
 		string("'") + to_string(GetBaseID()) 		  + string("', ") + 
 		string("'") + to_string(entryid) 		  + string("', ") + 
@@ -60,6 +61,7 @@ string Function_t::WriteToDB(File_t *fid, db_id_t newid)
 		string("'") + to_string(stack_frame_size) 	  + string("', ") + 
 	        string("'") + to_string(out_args_region_size) 	  + string("', ") + 
 	        string("'") + to_string(use_fp) 		  + string("', ") + 
+	        string("'") + to_string(is_safe) 		  + string("', ") + 
 	        string("'") + to_string(function_type_id) 	  + string("', ") + 
 		string("'") + to_string(GetDoipID()) 		  + string("') ; ") ;
 
diff --git a/tools/meds2pdb/meds2pdb.cpp b/tools/meds2pdb/meds2pdb.cpp
index a01ae7e41..2075bb597 100644
--- a/tools/meds2pdb/meds2pdb.cpp
+++ b/tools/meds2pdb/meds2pdb.cpp
@@ -189,7 +189,7 @@ void insert_functions(int fileID, const vector<wahoo::Function*> &functions  )
   for (int i = 0; i < functions.size(); i += STRIDE)
   {  
     string query = "INSERT INTO " + functionTable;
-    query += " (function_id, name, stack_frame_size, out_args_region_size, use_frame_pointer) VALUES ";
+    query += " (function_id, name, stack_frame_size, out_args_region_size, use_frame_pointer, is_safe) VALUES ";
 
 
     for (int j = i; j < i + STRIDE; ++j)
@@ -205,6 +205,7 @@ void insert_functions(int fileID, const vector<wahoo::Function*> &functions  )
 
       int outArgsRegionSize = f->getOutArgsRegionSize();
       bool useFP = f->getUseFramePointer();
+      bool isSafe = f->isSafe();
 
       if (j != i) query += ",";
       query += "(";
@@ -212,7 +213,8 @@ void insert_functions(int fileID, const vector<wahoo::Function*> &functions  )
       query += txn.quote(functionName) + ",";
       query += txn.quote(functionFrameSize) + ",";
       query += txn.quote(outArgsRegionSize) + ",";
-      query += txn.quote(useFP) + ")";
+      query += txn.quote(useFP) + ",";
+      query += txn.quote(isSafe) + ")";
 
     }
 
diff --git a/tools/selective_cfi/scfi_driver.cpp b/tools/selective_cfi/scfi_driver.cpp
index 5fbed5e04..4bdf11adc 100644
--- a/tools/selective_cfi/scfi_driver.cpp
+++ b/tools/selective_cfi/scfi_driver.cpp
@@ -39,9 +39,10 @@ void usage(char* name)
 "		[--color|--no-color]  \n"
 "		[--protect-jumps|--no-protect-jumps]  \n"
 "		[--protect-rets|--no-protect-rets] \n"
+"		[--protect-safefn|--no-protect-safefn]  \n"
 "		[ --common-slow-path | --no-common-slow-path ] \n"
 " \n"
-"default: --no-color --protect-jumps --protect-rets --common-slow-path\n"; 
+"default: --no-color --protect-jumps --protect-rets --no-protect-safefn --common-slow-path\n"; 
 }
 
 int main(int argc, char **argv)
@@ -62,6 +63,7 @@ int main(int argc, char **argv)
 	bool do_common_slow_path=true;
 	bool do_jumps=true;
 	bool do_rets=true;
+	bool do_safefn=false;
 	for(int  i=2;i<argc;i++)
 	{
 		if(string(argv[i])=="--color")
@@ -94,6 +96,16 @@ int main(int argc, char **argv)
 			cout<<"Not protecting returns..."<<endl;
 			do_rets=false;
 		}
+		else if(string(argv[i])=="--protect-safefn")
+		{
+			cout<<"protecting safe functions..."<<endl;
+			do_safefn=true;
+		}
+		else if(string(argv[i])=="--no-protect-safefn")
+		{
+			cout<<"Not protecting safe functions..."<<endl;
+			do_safefn=false;
+		}
 		else if(string(argv[i])=="--common-slow-path")
 		{
 			cout<<"Using common slow path..."<<endl;
@@ -140,7 +152,7 @@ int main(int argc, char **argv)
 
                 try
                 {
-			SCFI_Instrument scfii(firp, do_coloring, do_common_slow_path, do_jumps, do_rets);
+			SCFI_Instrument scfii(firp, do_coloring, do_common_slow_path, do_jumps, do_rets, do_safefn);
 
 
 			int success=scfii.execute();
diff --git a/tools/selective_cfi/scfi_instr.cpp b/tools/selective_cfi/scfi_instr.cpp
index 3b1fe4391..960286158 100644
--- a/tools/selective_cfi/scfi_instr.cpp
+++ b/tools/selective_cfi/scfi_instr.cpp
@@ -217,7 +217,10 @@ Relocation_t* SCFI_Instrument::FindRelocation(Instruction_t* insn, string type)
         return NULL;
 }
 
-
+bool SCFI_Instrument::isSafeFunction(Instruction_t* insn)
+{
+	return (insn && insn->GetFunction() && insn->GetFunction()->IsSafe());
+}
 
 
 Relocation_t* SCFI_Instrument::create_reloc(Instruction_t* insn)
@@ -602,6 +605,8 @@ bool SCFI_Instrument::instrument_jumps()
 	int cfi_branch_call_complete=0;
 	int cfi_branch_ret_checks=0;
 	int cfi_branch_ret_complete=0;
+	int	cfi_safefn_jmp_skipped=0;
+	int	cfi_safefn_ret_skipped=0;
 	int ibt_complete=0;
 	double cfi_branch_jmp_complete_ratio = NAN;
 	double cfi_branch_ret_complete_ratio = NAN;
@@ -625,6 +630,8 @@ bool SCFI_Instrument::instrument_jumps()
 		if(FindRelocation(insn,"cf::safe"))
 			continue;
 
+		bool safefn = isSafeFunction(insn);
+
 		DISASM d;
 		insn->Disassemble(d);
 
@@ -634,13 +641,20 @@ bool SCFI_Instrument::instrument_jumps()
 			case  JmpType:
 				if((d.Argument1.ArgType&MEMORY_TYPE)==MEMORY_TYPE)
 				{
-					cfi_checks++;
-					cfi_branch_jmp_checks++;
 					if (insn->GetIBTargets() && insn->GetIBTargets()->IsComplete())
 					{
 						cfi_branch_jmp_complete++;
 						jmps[insn->GetIBTargets()->size()]++;
 					}
+
+					if (!do_safefn && safefn)
+					{
+						cfi_safefn_jmp_skipped++;
+						continue;
+					}
+
+					cfi_checks++;
+					cfi_branch_jmp_checks++;
 					AddJumpCFI(insn);
 				}
 				break;
@@ -656,14 +670,22 @@ bool SCFI_Instrument::instrument_jumps()
 					cfi_checks++;
 				}
 				break;
+
 			case  RetType: 
-				cfi_branch_ret_checks++;
 				if (insn->GetIBTargets() && insn->GetIBTargets()->IsComplete())
 				{
 					cfi_branch_ret_complete++;
 					rets[insn->GetIBTargets()->size()]++;
 				}
+
+				if (!do_safefn && safefn)
+				{
+					cfi_safefn_ret_skipped++;
+					continue;
+				}
+
 				cfi_checks++;
+				cfi_branch_ret_checks++;
 				AddReturnCFI(insn);
 				break;
 
@@ -673,7 +695,7 @@ bool SCFI_Instrument::instrument_jumps()
 	}
 	
 	cout<<"# ATTRIBUTE cfi_jmp_checks="<<std::dec<<cfi_branch_jmp_checks<<endl;
-	cout<<"# ATTRIBUTE cfi_jmp_complete="<<std::dec<<cfi_branch_jmp_complete<<endl;
+	cout<<"# ATTRIBUTE cfi_jmp_complete="<<cfi_branch_jmp_complete<<endl;
 
 	display_histogram(cout, "cfi_jmp_complete_histogram", jmps);
 
@@ -705,6 +727,9 @@ bool SCFI_Instrument::instrument_jumps()
 	cout << "# ATTRIBUTE cfi_ret_complete_ratio=" << cfi_branch_ret_complete_ratio << endl;
 	cout << "# ATTRIBUTE cfi_complete_ratio=" << cfi_branch_ret_complete_ratio << endl;
 
+	cout<<"# ATTRIBUTE cfi_safefn_jmp_skipped="<<cfi_safefn_jmp_skipped<<endl;
+	cout<<"# ATTRIBUTE cfi_safefn_ret_skipped="<<cfi_safefn_ret_skipped<<endl;
+
 	return true;
 }
 
diff --git a/tools/selective_cfi/scfi_instr.hpp b/tools/selective_cfi/scfi_instr.hpp
index dfb9e4404..57b659a37 100644
--- a/tools/selective_cfi/scfi_instr.hpp
+++ b/tools/selective_cfi/scfi_instr.hpp
@@ -33,12 +33,14 @@ class SCFI_Instrument
 				bool p_do_coloring=true,
 				bool p_do_common_slow_path=true,
 				bool p_do_jumps=true,
-				bool p_do_rets=true) 
+				bool p_do_rets=true,
+				bool p_do_safefn=true) 
 			: firp(the_firp), 
 			  do_coloring(p_do_coloring), 
 			  do_common_slow_path(p_do_common_slow_path), 
 			  do_jumps(p_do_jumps), 
 			  do_rets(p_do_rets), 
+			  do_safefn(p_do_safefn), 
 			  color_map(NULL) {}
 		bool execute();
 
@@ -52,6 +54,7 @@ class SCFI_Instrument
 		// helper
 		libIRDB::Relocation_t* create_reloc(libIRDB::Instruction_t* insn);
 		libIRDB::Relocation_t* FindRelocation(libIRDB::Instruction_t* insn, std::string type);
+		bool isSafeFunction(libIRDB::Instruction_t* insn);
 
 		// add instrumentation
 		bool add_scfi_instrumentation(libIRDB::Instruction_t* insn);
@@ -76,6 +79,7 @@ class SCFI_Instrument
 		bool do_common_slow_path;
 		bool do_jumps;
 		bool do_rets;
+		bool do_safefn;
 		ColoredInstructionNonces_t *color_map;
 
 
-- 
GitLab