diff --git a/libIRDB/include/core/function.hpp b/libIRDB/include/core/function.hpp index 49f123844ee4f1e3d0d328031e92a2aed433286b..85ef858c5d510aace6bf7e3874e58b19b6020ac3 100644 --- a/libIRDB/include/core/function.hpp +++ b/libIRDB/include/core/function.hpp @@ -28,7 +28,7 @@ class Function_t : public BaseObj_t Function_t() : BaseObj_t(NULL) {} // create a new function not in the db // create a function that's already in the DB - Function_t(db_id_t id, std::string name, int size, int oa_size, bool use_fp, FuncType_t *, Instruction_t *entry); + Function_t(db_id_t id, std::string name, int size, int oa_size, bool use_fp, bool is_safe, FuncType_t *, Instruction_t *entry); InstructionSet_t& GetInstructions() { return my_insns; } @@ -48,6 +48,9 @@ class Function_t : public BaseObj_t bool GetUseFramePointer() const { return use_fp; } void SetUseFramePointer(bool useFP) { use_fp = useFP; } + void SetSafe(bool safe) { is_safe = safe; } + bool IsSafe() const { return is_safe; } + void SetType(FuncType_t *t) { function_type = t; } FuncType_t* GetType() const { return function_type; } @@ -60,6 +63,7 @@ class Function_t : public BaseObj_t std::string name; int out_args_region_size; bool use_fp; + bool is_safe; FuncType_t *function_type; }; diff --git a/libIRDB/src/core/fileir.cpp b/libIRDB/src/core/fileir.cpp index 8e7c3860feae1a92b06d775a6c18af913240f884..ded349ff7a6e19e3f1497e87cf39158eeb69b127 100644 --- a/libIRDB/src/core/fileir.cpp +++ b/libIRDB/src/core/fileir.cpp @@ -284,7 +284,7 @@ std::map<db_id_t,Function_t*> FileIR_t::ReadFuncsFromDB while(!dbintr->IsDone()) { -// function_id | file_id | name | stack_frame_size | out_args_region_size | use_frame_pointer | doip_id +// function_id | file_id | name | stack_frame_size | out_args_region_size | use_frame_pointer | is_safe | doip_id db_id_t fid=atoi(dbintr->GetResultColumn("function_id").c_str()); db_id_t entry_point_id=atoi(dbintr->GetResultColumn("entry_point_id").c_str()); @@ -293,21 +293,30 @@ std::map<db_id_t,Function_t*> FileIR_t::ReadFuncsFromDB int oasize=atoi(dbintr->GetResultColumn("out_args_region_size").c_str()); db_id_t function_type_id=atoi(dbintr->GetResultColumn("type_id").c_str()); // postgresql encoding of boolean can be 'true', '1', 'T', 'y' - bool useFP=false; + bool useFP=false; + bool isSafe=false; string useFPString=dbintr->GetResultColumn("use_frame_pointer"); + string isSafeString=dbintr->GetResultColumn("is_safe"); const char *useFPstr=useFPString.c_str(); - if (strlen(useFPstr) > 0) + const char *isSafestr=isSafeString.c_str(); + if (strlen(useFPstr) > 0) { if (useFPstr[0] == 't' || useFPstr[0] == 'T' || useFPstr[0] == '1' || useFPstr[0] == 'y' || useFPstr[0] == 'Y') useFP = true; } + if (strlen(isSafestr) > 0) + { + if (isSafestr[0] == 't' || isSafestr[0] == 'T' || isSafestr[0] == '1' || isSafestr[0] == 'y' || isSafestr[0] == 'Y') + isSafe = true; + } + db_id_t doipid=atoi(dbintr->GetResultColumn("doip_id").c_str()); FuncType_t* fnType = NULL; if (typesMap.count(function_type_id) > 0) fnType = dynamic_cast<FuncType_t*>(typesMap[function_type_id]); - Function_t *newfunc=new Function_t(fid,name,sfsize,oasize,useFP,fnType, NULL); + Function_t *newfunc=new Function_t(fid,name,sfsize,oasize,useFP,isSafe,fnType, NULL); entry_points[newfunc]=entry_point_id; //std::cout<<"Found function "<<name<<"."<<std::endl; diff --git a/libIRDB/src/core/function.cpp b/libIRDB/src/core/function.cpp index dcad3fc58bef03df7dfcd2ba7b161df869ba01f5..4167c1b214814aeaccd840984b09d1045f9a6e9a 100644 --- a/libIRDB/src/core/function.cpp +++ b/libIRDB/src/core/function.cpp @@ -25,7 +25,7 @@ using namespace libIRDB; using namespace std; -Function_t::Function_t(db_id_t id, std::string myname, int size, int oa_size, bool useFP, FuncType_t *fn_type, Instruction_t* entry) +Function_t::Function_t(db_id_t id, std::string myname, int size, int oa_size, bool useFP, bool isSafe, FuncType_t *fn_type, Instruction_t* entry) : BaseObj_t(NULL), entry_point(entry) { SetBaseID(id); @@ -33,6 +33,7 @@ Function_t::Function_t(db_id_t id, std::string myname, int size, int oa_size, bo stack_frame_size=size; out_args_region_size=oa_size; use_fp = useFP; + SetSafe(isSafe); function_type = fn_type; } @@ -52,7 +53,7 @@ string Function_t::WriteToDB(File_t *fid, db_id_t newid) function_type_id = GetType()->GetBaseID(); string q=string("insert into ")+fid->function_table_name + - string(" (function_id, entry_point_id, name, stack_frame_size, out_args_region_size, use_frame_pointer, type_id, doip_id) ")+ + string(" (function_id, entry_point_id, name, stack_frame_size, out_args_region_size, use_frame_pointer, is_safe, type_id, doip_id) ")+ string(" VALUES (") + string("'") + to_string(GetBaseID()) + string("', ") + string("'") + to_string(entryid) + string("', ") + @@ -60,6 +61,7 @@ string Function_t::WriteToDB(File_t *fid, db_id_t newid) string("'") + to_string(stack_frame_size) + string("', ") + string("'") + to_string(out_args_region_size) + string("', ") + string("'") + to_string(use_fp) + string("', ") + + string("'") + to_string(is_safe) + string("', ") + string("'") + to_string(function_type_id) + string("', ") + string("'") + to_string(GetDoipID()) + string("') ; ") ; diff --git a/tools/meds2pdb/meds2pdb.cpp b/tools/meds2pdb/meds2pdb.cpp index a01ae7e41d3a81fc0a519b4857db85917658ee5a..2075bb59724623edb6f69fb51e1f1190f1800a99 100644 --- a/tools/meds2pdb/meds2pdb.cpp +++ b/tools/meds2pdb/meds2pdb.cpp @@ -189,7 +189,7 @@ void insert_functions(int fileID, const vector<wahoo::Function*> &functions ) for (int i = 0; i < functions.size(); i += STRIDE) { string query = "INSERT INTO " + functionTable; - query += " (function_id, name, stack_frame_size, out_args_region_size, use_frame_pointer) VALUES "; + query += " (function_id, name, stack_frame_size, out_args_region_size, use_frame_pointer, is_safe) VALUES "; for (int j = i; j < i + STRIDE; ++j) @@ -205,6 +205,7 @@ void insert_functions(int fileID, const vector<wahoo::Function*> &functions ) int outArgsRegionSize = f->getOutArgsRegionSize(); bool useFP = f->getUseFramePointer(); + bool isSafe = f->isSafe(); if (j != i) query += ","; query += "("; @@ -212,7 +213,8 @@ void insert_functions(int fileID, const vector<wahoo::Function*> &functions ) query += txn.quote(functionName) + ","; query += txn.quote(functionFrameSize) + ","; query += txn.quote(outArgsRegionSize) + ","; - query += txn.quote(useFP) + ")"; + query += txn.quote(useFP) + ","; + query += txn.quote(isSafe) + ")"; } diff --git a/tools/selective_cfi/scfi_driver.cpp b/tools/selective_cfi/scfi_driver.cpp index 5fbed5e04b04c527de2091c4832cd8d421538d0e..4bdf11adc68bd3a5cb975cde52f601b8258ae1f2 100644 --- a/tools/selective_cfi/scfi_driver.cpp +++ b/tools/selective_cfi/scfi_driver.cpp @@ -39,9 +39,10 @@ void usage(char* name) " [--color|--no-color] \n" " [--protect-jumps|--no-protect-jumps] \n" " [--protect-rets|--no-protect-rets] \n" +" [--protect-safefn|--no-protect-safefn] \n" " [ --common-slow-path | --no-common-slow-path ] \n" " \n" -"default: --no-color --protect-jumps --protect-rets --common-slow-path\n"; +"default: --no-color --protect-jumps --protect-rets --no-protect-safefn --common-slow-path\n"; } int main(int argc, char **argv) @@ -62,6 +63,7 @@ int main(int argc, char **argv) bool do_common_slow_path=true; bool do_jumps=true; bool do_rets=true; + bool do_safefn=false; for(int i=2;i<argc;i++) { if(string(argv[i])=="--color") @@ -94,6 +96,16 @@ int main(int argc, char **argv) cout<<"Not protecting returns..."<<endl; do_rets=false; } + else if(string(argv[i])=="--protect-safefn") + { + cout<<"protecting safe functions..."<<endl; + do_safefn=true; + } + else if(string(argv[i])=="--no-protect-safefn") + { + cout<<"Not protecting safe functions..."<<endl; + do_safefn=false; + } else if(string(argv[i])=="--common-slow-path") { cout<<"Using common slow path..."<<endl; @@ -140,7 +152,7 @@ int main(int argc, char **argv) try { - SCFI_Instrument scfii(firp, do_coloring, do_common_slow_path, do_jumps, do_rets); + SCFI_Instrument scfii(firp, do_coloring, do_common_slow_path, do_jumps, do_rets, do_safefn); int success=scfii.execute(); diff --git a/tools/selective_cfi/scfi_instr.cpp b/tools/selective_cfi/scfi_instr.cpp index 3b1fe439122cad27a19ddb65d061235228f47602..9602861585a9546edb88b586e19aeb365877e1ca 100644 --- a/tools/selective_cfi/scfi_instr.cpp +++ b/tools/selective_cfi/scfi_instr.cpp @@ -217,7 +217,10 @@ Relocation_t* SCFI_Instrument::FindRelocation(Instruction_t* insn, string type) return NULL; } - +bool SCFI_Instrument::isSafeFunction(Instruction_t* insn) +{ + return (insn && insn->GetFunction() && insn->GetFunction()->IsSafe()); +} Relocation_t* SCFI_Instrument::create_reloc(Instruction_t* insn) @@ -602,6 +605,8 @@ bool SCFI_Instrument::instrument_jumps() int cfi_branch_call_complete=0; int cfi_branch_ret_checks=0; int cfi_branch_ret_complete=0; + int cfi_safefn_jmp_skipped=0; + int cfi_safefn_ret_skipped=0; int ibt_complete=0; double cfi_branch_jmp_complete_ratio = NAN; double cfi_branch_ret_complete_ratio = NAN; @@ -625,6 +630,8 @@ bool SCFI_Instrument::instrument_jumps() if(FindRelocation(insn,"cf::safe")) continue; + bool safefn = isSafeFunction(insn); + DISASM d; insn->Disassemble(d); @@ -634,13 +641,20 @@ bool SCFI_Instrument::instrument_jumps() case JmpType: if((d.Argument1.ArgType&MEMORY_TYPE)==MEMORY_TYPE) { - cfi_checks++; - cfi_branch_jmp_checks++; if (insn->GetIBTargets() && insn->GetIBTargets()->IsComplete()) { cfi_branch_jmp_complete++; jmps[insn->GetIBTargets()->size()]++; } + + if (!do_safefn && safefn) + { + cfi_safefn_jmp_skipped++; + continue; + } + + cfi_checks++; + cfi_branch_jmp_checks++; AddJumpCFI(insn); } break; @@ -656,14 +670,22 @@ bool SCFI_Instrument::instrument_jumps() cfi_checks++; } break; + case RetType: - cfi_branch_ret_checks++; if (insn->GetIBTargets() && insn->GetIBTargets()->IsComplete()) { cfi_branch_ret_complete++; rets[insn->GetIBTargets()->size()]++; } + + if (!do_safefn && safefn) + { + cfi_safefn_ret_skipped++; + continue; + } + cfi_checks++; + cfi_branch_ret_checks++; AddReturnCFI(insn); break; @@ -673,7 +695,7 @@ bool SCFI_Instrument::instrument_jumps() } cout<<"# ATTRIBUTE cfi_jmp_checks="<<std::dec<<cfi_branch_jmp_checks<<endl; - cout<<"# ATTRIBUTE cfi_jmp_complete="<<std::dec<<cfi_branch_jmp_complete<<endl; + cout<<"# ATTRIBUTE cfi_jmp_complete="<<cfi_branch_jmp_complete<<endl; display_histogram(cout, "cfi_jmp_complete_histogram", jmps); @@ -705,6 +727,9 @@ bool SCFI_Instrument::instrument_jumps() cout << "# ATTRIBUTE cfi_ret_complete_ratio=" << cfi_branch_ret_complete_ratio << endl; cout << "# ATTRIBUTE cfi_complete_ratio=" << cfi_branch_ret_complete_ratio << endl; + cout<<"# ATTRIBUTE cfi_safefn_jmp_skipped="<<cfi_safefn_jmp_skipped<<endl; + cout<<"# ATTRIBUTE cfi_safefn_ret_skipped="<<cfi_safefn_ret_skipped<<endl; + return true; } diff --git a/tools/selective_cfi/scfi_instr.hpp b/tools/selective_cfi/scfi_instr.hpp index dfb9e44047876a141effaf5e6fa05ec0a91f622d..57b659a37867d5742d061c9d07ef33cf84bef068 100644 --- a/tools/selective_cfi/scfi_instr.hpp +++ b/tools/selective_cfi/scfi_instr.hpp @@ -33,12 +33,14 @@ class SCFI_Instrument bool p_do_coloring=true, bool p_do_common_slow_path=true, bool p_do_jumps=true, - bool p_do_rets=true) + bool p_do_rets=true, + bool p_do_safefn=true) : firp(the_firp), do_coloring(p_do_coloring), do_common_slow_path(p_do_common_slow_path), do_jumps(p_do_jumps), do_rets(p_do_rets), + do_safefn(p_do_safefn), color_map(NULL) {} bool execute(); @@ -52,6 +54,7 @@ class SCFI_Instrument // helper libIRDB::Relocation_t* create_reloc(libIRDB::Instruction_t* insn); libIRDB::Relocation_t* FindRelocation(libIRDB::Instruction_t* insn, std::string type); + bool isSafeFunction(libIRDB::Instruction_t* insn); // add instrumentation bool add_scfi_instrumentation(libIRDB::Instruction_t* insn); @@ -76,6 +79,7 @@ class SCFI_Instrument bool do_common_slow_path; bool do_jumps; bool do_rets; + bool do_safefn; ColoredInstructionNonces_t *color_map;