//
// SMPProgram.cpp
//
// This module analyzes the whole program as needed for the
//   SMP project (Software Memory Protection).
//

#include <utility>
#include <list>
#include <set>
#include <vector>
#include <algorithm>

#include <cstring>
#include <cstdlib>

#include <pro.h>
#include <assert.h>
#include <ida.hpp>
#include <idp.hpp>
#include <auto.hpp>
#include <bytes.hpp>
#include <funcs.hpp>
#include <allins.hpp>
#include <intel.hpp>
#include <name.hpp>

#include "SMPDataFlowAnalysis.h"
#include "SMPStaticAnalyzer.h"
#include "SMPFunction.h"
#include "SMPBasicBlock.h"
#include "SMPInstr.h"
#include "SMPProgram.h"

// Compile-time switches for this module. Every switch is tested with #if,
//  so set it to 1 to enable the feature or 0 to disable it.

// Set to 1 for debugging output
//   SMP_DEBUG                       - segment progress messages in
//                                     InitStaticDataTable() and Analyze()
//   SMP_DEBUG_GLOBAL_GRANULARITY    - xref/offset messages in
//                                     ComputeGlobalFieldOffsets() and EmitAnnotations()
//   SMP_DEBUG_OPTIMIZATIONS         - "Inferring types" message per function in Analyze()
//   SMP_DEBUG_OPTIMIZATIONS_VERBOSE - Dump() of the hard-wired debug function in Analyze()
//   SMP_DEBUG_FUNC                  - return-address safety decisions in
//                                     RecurseAndMarkRetAdd()
#define SMP_DEBUG 1
#define SMP_DEBUG_GLOBAL_GRANULARITY 0
#define SMP_DEBUG_OPTIMIZATIONS 1
#define SMP_DEBUG_OPTIMIZATIONS_VERBOSE 1
#define SMP_DEBUG_FUNC 0

// Compute fine-grained global static data boundaries?
//  (Gates the ComputeGlobalFieldOffsets() call in InitStaticDataTable().)
#define SMP_COMPUTE_GLOBAL_GRANULARITY 1
// Distinguish between indexed and direct accesses in global granularity?
//  (When 0, MDIsIndexedAccess() always answers false.)
#define SMP_DETECT_INDEXED_ACCESSES 0
// Use type inference to compute optimizing annotations.
//  (Gates the InferTypes() pass in Analyze().)
#define SMP_INFER_TYPES 1

// Address ranges computed by SMPProgram::InitStaticDataTable().
//  The ranges are inclusive: the "Highest" values hold the last byte
//  (end - 1) of the highest data object / code segment seen. Before that
//  function runs, or if nothing is found, they describe an empty range
//  (Lowest = 0xffffffff, Highest = 0).
// NOTE(review): these are non-static globals, presumably declared extern in a
//  project header and read by other modules; confirm before changing them.
ea_t LowestGlobalVarAddress;
ea_t HighestGlobalVarAddress;
ea_t LowestCodeAddress;
ea_t HighestCodeAddress;

// Does the instruction at InstAddr access the global data offset in GlobalAddr
//  using an index register?
//
// The instruction is decoded with ua_ana0() and the operands left in cmd are
//  inspected. An o_mem/o_displ operand whose address equals GlobalAddr counts
//  as indexed when either of these holds:
//   - it has a SIB byte whose index field is not the R_sp "no index" dummy, or
//   - it is an o_displ operand without a SIB byte.
// With SMP_DETECT_INDEXED_ACCESSES disabled the decode still happens, but the
//  answer is always false.
bool MDIsIndexedAccess(ea_t InstAddr, ea_t GlobalAddr) {
	int DecodedLen = ua_ana0(InstAddr);
#if !SMP_DETECT_INDEXED_ACCESSES
	(void) DecodedLen;
	return false;
#else
	if (DecodedLen <= 0)
		return false; // nothing decodable at InstAddr

	bool Verbose = (0x80502d3 == InstAddr); // hard-wired debugging address
	for (int OpNum = 0; OpNum < UA_MAXOP; ++OpNum) {
		op_t Operand = cmd.Operands[OpNum];
		bool IsMemoryRef = (o_mem == Operand.type) || (o_displ == Operand.type);
		if (!IsMemoryRef || (Operand.addr != GlobalAddr))
			continue; // operand does not reference GlobalAddr

		if (Operand.hasSIB) {
			// R_sp in the SIB index field is a dummy that means "no index register".
			regnum_t IndexReg = sib_index(Operand);
			if (IndexReg != R_sp)
				return true;
		}
		else if (o_displ == Operand.type) {
			// index reg in reg field, not in SIB byte
			return true;
		}
		else if (Verbose) {
			msg("Failed to find index in operand: ");
			PrintOneOperand(Operand, 0, -1);
			msg("\n");
		}
	}
	return false;
#endif
} // end MDIsIndexedAccess()

// *****************************************************************
// Class SMPProgram
// *****************************************************************

// Constructor: the program starts out with no SMPFunction objects recorded.
//  The explicit clear() documents that FuncMap begins empty.
SMPProgram::SMPProgram(void) {
	FuncMap.clear();
}

// Destructor: FuncMap owns every SMPFunction it points to (they are created
//  with new in Analyze()), so release each one. Deleting a NULL value is a no-op.
SMPProgram::~SMPProgram(void) {
	map<ea_t, SMPFunction *>::iterator Entry = this->FuncMap.begin();
	while (Entry != this->FuncMap.end()) {
		SMPFunction *Doomed = Entry->second;
		delete Doomed;
		++Entry;
	}
}

// Determine static global variable boundaries.
void SMPProgram::InitStaticDataTable(void) {
	segment_t *seg;
	char buf[MAXSTR];
	ea_t ea;
	flags_t ObjFlags;
	bool ReadOnlyFlag;

	// First, examine the data segments and collect info about static
	//   data, such as name/address/size.

	LowestGlobalVarAddress  = 0xffffffff;
	HighestGlobalVarAddress = 0x00000000;
	LowestCodeAddress  = 0xffffffff;
	HighestCodeAddress = 0x00000000;

	// Loop through all segments.
	for (int SegIndex = 0; SegIndex < get_segm_qty(); ++SegIndex) {
		char SegName[MAXSTR];
		seg = getnseg(SegIndex);
		ssize_t SegNameSize = get_segm_name(seg, SegName, sizeof(SegName) - 1);

		// We are only interested in the data segments of type
		// SEG_DATA, SEG_BSS and SEG_COMM.
		if ((seg->type == SEG_DATA) || (seg->type == SEG_BSS)
		    || (seg->type == SEG_COMM)) {
			// Loop through each of the segments we are interested in,
			//  examining all data objects (effective addresses).
			ReadOnlyFlag = ((seg->perm & SEGPERM_READ) && (!(seg->perm & SEGPERM_WRITE)));
#if SMP_DEBUG
			msg("Starting data segment of type %d", seg->type);
			if (SegNameSize > 0)
				msg(" SegName: %s\n", SegName);
			else
				msg("\n");
			if (ReadOnlyFlag) {
				msg("Read-only data segment.\n");
			}
#endif
			ea = seg->startEA;
			while (ea < seg->endEA) {
				ObjFlags = get_flags_novalue(ea);
				// Only process head bytes of data objects, i.e. isData().
				if (isData(ObjFlags)) {
				    // Compute the size of the data object.
					ea_t NextEA = ea;
				    do {
				       NextEA = nextaddr(NextEA);
					} while ((NextEA < seg->endEA) && (!isHead(get_flags_novalue(NextEA))));
				    size_t ObjSize = (size_t) (NextEA - ea);
					if (LowestGlobalVarAddress > ea)
						LowestGlobalVarAddress = ea;
					if (HighestGlobalVarAddress < NextEA)
						HighestGlobalVarAddress = NextEA - 1;
					// Get the data object name using its address.
				    char *TrueName = get_true_name(BADADDR, ea, buf, sizeof(buf));
					if (NULL == TrueName) {
						qstrncpy(buf, "SMP_dummy0", 12);
					}

				    // Record the name, address, size, and type info.
					struct GlobalVar VarTemp;
					VarTemp.addr = ea;
					VarTemp.size = ObjSize;
					VarTemp.ReadOnly = ReadOnlyFlag;
					VarTemp.flags = ObjFlags;
					qstrncpy(VarTemp.name, buf, MAXSTR - 1);
					VarTemp.FieldOffsets.clear();
#if SMP_COMPUTE_GLOBAL_GRANULARITY
					this->ComputeGlobalFieldOffsets(VarTemp);
#endif
					pair<ea_t, struct GlobalVar> TempItem(ea, VarTemp);
					this->GlobalVarTable.insert(TempItem);
					// Move on to next data object
					ea = NextEA;
				}
				else {
					ea = nextaddr(ea);
				}
			} // end while (ea < seg->endEA)
		} // end if (seg->type == SEG_DATA ...)
		else if (seg->type == SEG_CODE) {
			if (seg->startEA < LowestCodeAddress)
				LowestCodeAddress = seg->startEA;
			if (seg->endEA > HighestCodeAddress)
				HighestCodeAddress = seg->endEA - 1;
		} // end else if (seg->type === SEG_CODE)
		else {
#if SMP_DEBUG
			msg("Not processing segment of type %d SegName: %s\n",
				seg->type, SegName);
#endif
		}
	} // end for (int SegIndex = 0; ... )

	return;
} // end of SMPProgram::InitStaticDataTable()

// Find the direct and indexed accesses to offsets within each static data table entry.
//  Record the offset and kind of access (indexed or not) and conservatively mark the
//  field boundaries based on the unindexed accesses.
//
// For every byte offset in [0, CurrGlobal.size), the cross-references TO that
//  address are walked. An offset with at least one data xref of kind dr_O,
//  dr_W or dr_R is inserted into CurrGlobal.FieldOffsets. The paired bool is
//  true if MDIsIndexedAccess() reports an index register for any of those
//  referencing instructions. Code xrefs and other data xref kinds are ignored.
void SMPProgram::ComputeGlobalFieldOffsets(struct GlobalVar &CurrGlobal) {
	xrefblk_t xb;
	// Iterate over offsets rather than addresses: "addr < base + size" could
	//  wrap around for an object that ends at the top of the address space.
	for (size_t offset = 0; offset < CurrGlobal.size; ++offset) {
		ea_t addr = CurrGlobal.addr + (ea_t) offset;
		bool Referenced = false;
		pair<size_t, bool> TempOffset(offset, false); // No indexed accesses seen yet
		for (bool ok = xb.first_to(addr, XREF_ALL); ok; ok = xb.next_to()) {
			if (xb.iscode) {
#if SMP_DEBUG_GLOBAL_GRANULARITY
				msg("WARNING: code xref to global data at %x\n", addr);
#endif
				continue;
			}
			uchar XrefType = xb.type & XREF_MASK;
			if ((XrefType == dr_O) || (XrefType == dr_W) || (XrefType == dr_R)) {
#if SMP_DEBUG_GLOBAL_GRANULARITY
				SMPInstr TempInstr(xb.from);
				TempInstr.Analyze();
				msg("Data xref to global data %s at %x from code at %x %s\n",
					CurrGlobal.name, addr, xb.from, TempInstr.GetDisasm());
#endif
				Referenced = true;
				TempOffset.second |= MDIsIndexedAccess(xb.from, addr);
			}
#if SMP_DEBUG_GLOBAL_GRANULARITY
			else {
				msg("WARNING: Weird data xref type %d at %x\n", XrefType, xb.from);
			}
#endif
		} // end for (bool ok = iterate through xrefs ...)
		if (Referenced) {
			CurrGlobal.FieldOffsets.insert(TempOffset);
		}
	} // end for all offsets in current global
} // end of SMPProgram::ComputeGlobalFieldOffsets()

// Main program analysis driver. Goes through all functions and
//  analyzes all functions and global static data.
//
// Global static data is collected first (InitStaticDataTable()). Then, for
//  every SEG_CODE segment, each function starting inside that segment gets a
//  new SMPFunction. The object is stored in FuncMap (which owns it and frees it
//  in ~SMPProgram) and then analyzed. When SMP_INFER_TYPES is enabled, functions
//  with good RTLs and no indirect jumps also get type inference; their
//  typed/untyped DEF counts are summed and reported at the end.
void SMPProgram::Analyze(void) {
	long TotalTypedDefs = 0;
	long TotalUntypedDefs = 0;

	// Collect initial info about global static data objects.
	this->InitStaticDataTable();

	// Collect initial info about all functions, one code segment at a time.
	for (int SegIndex = 0; SegIndex < get_segm_qty(); ++SegIndex) {
		segment_t *seg = getnseg(SegIndex);
		if ((NULL == seg) || (seg->type != SEG_CODE))
			continue;
#if SMP_DEBUG
		char SegName[MAXSTR];
		ssize_t SegNameSize = get_segm_name(seg, SegName, sizeof(SegName) - 1);
		msg("Starting code segment");
		if (SegNameSize > 0)
			msg(" SegName: %s\n", SegName);
		else
			msg("\n");
#endif
		for (size_t FuncIndex = 0; FuncIndex < get_func_qty(); ++FuncIndex) {
			func_t *FuncInfo = getn_func(FuncIndex);
			if (NULL == FuncInfo)
				continue;

			// get_func_qty() counts functions in the whole program, not just
			//  this segment, so only process functions inside the current
			//  segment. The break assumes getn_func() enumerates functions in
			//  ascending address order.
			if (FuncInfo->startEA < seg->startEA)
				continue; // Already processed this func in earlier segment.
			if (FuncInfo->startEA >= seg->endEA)
				break;

			// Create a function object; FuncMap takes ownership before analysis.
			SMPFunction *CurrFunc = new SMPFunction(FuncInfo);
			pair<ea_t, SMPFunction *> TempFunc(FuncInfo->startEA, CurrFunc);
			this->FuncMap.insert(TempFunc);
			CurrFunc->Analyze();
#if SMP_INFER_TYPES
			if (CurrFunc->HasGoodRTLs() && !CurrFunc->HasIndirectJumps()) {
#if SMP_DEBUG_OPTIMIZATIONS
				msg("Inferring types for function %s\n", CurrFunc->GetFuncName());
#endif
				CurrFunc->InferTypes();
				TotalTypedDefs += CurrFunc->GetTypedDefs();
				TotalUntypedDefs += CurrFunc->GetUntypedDefs();
#if SMP_DEBUG_OPTIMIZATIONS_VERBOSE
				// Hard-wired debugging hook. The decision is made per function,
				//  so a debug function that skips type inference cannot cause
				//  some later, unrelated function to be dumped instead.
				if (0 == strcmp("EvVar", CurrFunc->GetFuncName())) {
					CurrFunc->Dump();
				}
#endif
			}
#endif // SMP_INFER_TYPES
		} // end for (size_t FuncIndex = 0; ...)
	} // end for all segments

	// The totals are long: %ld, not %d, avoids a vararg type mismatch on LP64.
	msg("Total Typed DEFs: %ld\n", TotalTypedDefs);
	msg("Total Untyped DEFs: %ld\n", TotalUntypedDefs);
} // end of SMPProgram::Analyze()

// Emit all annotations for the program.
void SMPProgram::EmitAnnotations(FILE *AnnotFile) {
	// Emit global static data annotations first.
	map<ea_t, struct GlobalVar>::iterator GlobalIter;
	for (GlobalIter = this->GlobalVarTable.begin(); GlobalIter != this->GlobalVarTable.end(); ++GlobalIter) {
	    // Output the name, address, size, and type info.
		struct GlobalVar TempGlobal = GlobalIter->second;
		// If we have an offset other than 0 but not 0, add 0 to offsets.
		pair<size_t,bool> FirstOffset = (*(TempGlobal.FieldOffsets.begin()));
		if (0 != FirstOffset.first) {
			pair<size_t,bool> TempOffset;
			TempOffset.first = 0;
			TempOffset.second = false;
			TempGlobal.FieldOffsets.insert(TempOffset);
#if SMP_DEBUG_GLOBAL_GRANULARITY
			msg("Inserted offset 0 for global var %s\n", TempGlobal.name);
#endif
		}
		unsigned long ParentReferentID = DataReferentID++;
		bool DirectAccesses = false;  // Any direct field accesses in this global?
		set<pair<size_t, bool>, LessOff>::iterator CurrOffset;
		for (CurrOffset = TempGlobal.FieldOffsets.begin(); CurrOffset != TempGlobal.FieldOffsets.end(); ++CurrOffset) {
			pair<size_t, bool> TempOffset = *CurrOffset;
			if (!TempOffset.second)
				DirectAccesses = true;
		}
		// If 0 is the only direct offset, the data is not structured.
		if (!DirectAccesses || (2 > TempGlobal.FieldOffsets.size())) {
			// No fields within object, or all fields were accessed through indices
			if (TempGlobal.ReadOnly) {
				qfprintf(AnnotFile, 
					"%10x %6d DATAREF GLOBAL %d %x PARENT %s  %s RO\n",
					0, TempGlobal.size, ParentReferentID, TempGlobal.addr,
					TempGlobal.name, DataTypes[get_optype_flags0(TempGlobal.flags) >> 20]);
			}
			else {
				qfprintf(AnnotFile, 
					"%10x %6d DATAREF GLOBAL %d %x PARENT %s  %s RW\n",
					0, TempGlobal.size, ParentReferentID, TempGlobal.addr,
					TempGlobal.name, DataTypes[get_optype_flags0(TempGlobal.flags) >> 20]);
			}
		}
		else { // structured object with fields
			// Put out annotation for whole struct first
			if (TempGlobal.ReadOnly) {
				qfprintf(AnnotFile, 
					"%10x %6d DATAREF GLOBAL %d %x PARENT %s  %s RO AGGREGATE\n",
					0, TempGlobal.size, ParentReferentID, TempGlobal.addr,
					TempGlobal.name, DataTypes[get_optype_flags0(TempGlobal.flags) >> 20]);
			}
			else {
				qfprintf(AnnotFile, 
					"%10x %6d DATAREF GLOBAL %d %x PARENT %s  %s RW AGGREGATE\n",
					0, TempGlobal.size, ParentReferentID, TempGlobal.addr,
					TempGlobal.name, DataTypes[get_optype_flags0(TempGlobal.flags) >> 20]);
			}
			// Now, emit an annotation for each field offset.
			set<pair<size_t,bool>, LessOff>::iterator FieldIter, TempIter;
			size_t counter = 1;
			size_t FieldSize;
			for (FieldIter = TempGlobal.FieldOffsets.begin(); FieldIter != TempGlobal.FieldOffsets.end(); ++FieldIter, ++counter) {
				pair<size_t,bool> CurrOffset = (*FieldIter);
				if (counter < TempGlobal.FieldOffsets.size()) {
					TempIter = FieldIter;
					++TempIter;
					pair<size_t,bool> TempOffset = (*TempIter);
					FieldSize = TempOffset.first - CurrOffset.first;
				}
				else {
					FieldSize = TempGlobal.size - CurrOffset.first;
				}
				qfprintf(AnnotFile, 
					"%10x %6d DATAREF GLOBAL %d %x CHILDOF %d OFFSET %d %s + %d FIELD",
					0, FieldSize, DataReferentID, TempGlobal.addr, ParentReferentID, 
					CurrOffset.first, TempGlobal.name, CurrOffset.first);
				if (CurrOffset.second) { // indexed accesses to this field
					qfprintf(AnnotFile, " INDEXED\n");
				}
				else { // only direct accesses to this field
					qfprintf(AnnotFile, " DIRECT\n");
				}
				++DataReferentID;
			}
		} // end if unstructured data ... else ... 
	} // end for all globals in the global var table

	// Loop through all functions and emit annotations for each.
	map<ea_t, SMPFunction *>::iterator FuncIter;
	for (FuncIter = this->FuncMap.begin(); FuncIter != this->FuncMap.end(); ++FuncIter) {
		SMPFunction *TempFunc = FuncIter->second;
		if (TempFunc == NULL) continue;
		if (TempFunc->GetReturnAddressStatus() == FUNC_UNKNOWN) {
			// add func name to set
			 set<const char*,LtStr>::const_iterator find = FuncNameSet.find(TempFunc->GetFuncName());
			 if (find == FuncNameSet.end()) {
				 FuncNameSet.insert(TempFunc->GetFuncName());
				 RecurseAndMarkRetAdd(TempFunc);
				 FuncNameSet.erase(TempFunc->GetFuncName()); //remove name
			 }
		}
		TempFunc->EmitAnnotations(AnnotFile);
	} // end for all functions
	return;
} // end of SMPProgram::EmitAnnotations()

/**
 * If a function is still marked FUNC_UNKNOWN at the time of emitting
 * annotations, this function traverses over the call graph rooted at this
 * function and checks if they are marked safe.
 *
 * The function becomes FUNC_UNSAFE as soon as any callee is, or resolves to,
 * FUNC_UNSAFE; otherwise it becomes FUNC_SAFE. Two kinds of callee do not make
 * the caller unsafe:
 *  - call targets with no SMPFunction in FuncMap;
 *  - FUNC_UNKNOWN callees already on the current recursion path. These are
 *    tracked by name in FuncNameSet, which breaks call-graph cycles.
 * Leaf functions must already be resolved (asserted); their status is
 * returned unchanged.
 *
 * @param FuncAttrib function to resolve; non-leaf status is updated in place.
 * @return the resulting return-address status of FuncAttrib.
 */
FuncType SMPProgram::RecurseAndMarkRetAdd(SMPFunction* FuncAttrib) {
	if (FuncAttrib->IsLeaf()) {
#if SMP_DEBUG_FUNC
		if (FuncAttrib->GetReturnAddressStatus()  == FUNC_UNKNOWN)
			msg(" Leaf Function %s found with status unknown", FuncAttrib->GetFuncName());
#endif
		assert(FuncAttrib->GetReturnAddressStatus()  != FUNC_UNKNOWN);
		return FuncAttrib->GetReturnAddressStatus();
	}
	vector<ea_t> CallTargets = FuncAttrib->GetCallTargets();
	for (size_t i = 0; i < CallTargets.size(); ++i) {
		ea_t CallAddr = CallTargets[i];
		// Look the target up with find(): operator[] would insert a NULL
		//  entry into FuncMap for every unknown call target, and FuncMap
		//  walkers such as Dump() dereference entries without a NULL check.
		map<ea_t, SMPFunction *>::iterator ChildIter = this->FuncMap.find(CallAddr);
		SMPFunction* ChildInstance = (ChildIter == this->FuncMap.end()) ? NULL : ChildIter->second;
		if (!ChildInstance) {
#if SMP_DEBUG_FUNC
			// if a call target doesnt have a SMPFunction instance note it down
			msg(" Function doesnt have SMPFunction instance at %x \n", CallAddr);
#endif
			continue;
		}
		switch (ChildInstance->GetReturnAddressStatus()) {
			case FUNC_SAFE:
				continue;

			case FUNC_UNSAFE:
				FuncAttrib->SetReturnAddressStatus(FUNC_UNSAFE);
#if SMP_DEBUG_FUNC
				// if a call target is unsafe note it down
				msg("Function marked as unsafe %s\n", FuncAttrib->GetFuncName());
#endif
				return FUNC_UNSAFE;

			case FUNC_UNKNOWN:
			{
				// Only recurse into callees not already on the current path.
				set<const char*,LtStr>::const_iterator find = FuncNameSet.find(ChildInstance->GetFuncName());
				if (find == FuncNameSet.end()) {
					FuncNameSet.insert(ChildInstance->GetFuncName());
					FuncType Type = RecurseAndMarkRetAdd(ChildInstance);
					FuncNameSet.erase(ChildInstance->GetFuncName());
					if (Type == FUNC_UNSAFE) {
						FuncAttrib->SetReturnAddressStatus(FUNC_UNSAFE);
#if SMP_DEBUG_FUNC
						// if a call target is unsafe note it down
						msg("Function marked as unsafe %s\n", FuncAttrib->GetFuncName());
#endif
						return FUNC_UNSAFE;
					}
				}
				break;
			}
		} // end switch on child return address status
	} // end for all call targets
#if SMP_DEBUG_FUNC
	// if a call target is safe, note it
	msg("Function marked as safe %s\n", FuncAttrib->GetFuncName());
#endif

	FuncAttrib->SetReturnAddressStatus(FUNC_SAFE);
	return FUNC_SAFE;
} // end of SMPProgram::RecurseAndMarkRetAddr()

// Debug output dump.
void SMPProgram::Dump(void) {
	// Loop through all functions and call the debug Dump() for each.
	map<ea_t, SMPFunction *>::iterator FuncIter;
	for (FuncIter = this->FuncMap.begin(); FuncIter != this->FuncMap.end(); ++FuncIter) {
		SMPFunction *TempFunc = FuncIter->second;
		TempFunc->Dump();
	} // end for all functions
	return;
} // end of SMPProgram::Dump()