Skip to content
Snippets Groups Projects
SMPInstr.cpp 318 KiB
Newer Older
void SMPInstr::MDFindLoadFromStack(bool UseFP) {
	set<DefOrUse, LessDefUse>::iterator UseIter;
	op_t UseOp;

	if ((3 == this->OptType) && (this->HasSourceMemoryOperand())) {
		// Loads and stores are OptCategory 3. We want only loads from the stack.
		for (UseIter = this->GetFirstUse(); UseIter != this->GetLastUse(); ++UseIter) {
			UseOp = UseIter->GetOp();
			if (MDIsStackAccessOpnd(UseOp, UseFP)) {
				this->SetLoadFromStack();
				break;
			}
		}
	}
	return;
} // end of SMPInstr::MDFindLoadFromStack()

// Determine whether the instruction is a load that inherently extends its source.
//  Returns true for NN_movsx (SignMask receives FG_MASK_SIGNED) and for
//  NN_movzx (SignMask receives FG_MASK_UNSIGNED). Returns false for every other
//  opcode, in which case SignMask is left untouched.
bool SMPInstr::MDIsSignedLoad(unsigned short &SignMask) {
	const unsigned short InstType = this->SMPcmd.itype;

	if (NN_movsx == InstType) {
		SignMask = FG_MASK_SIGNED;
		return true;
	}
	if (NN_movzx == InstType) {
		SignMask = FG_MASK_UNSIGNED;
		return true;
	}
	return false;
}

// Infer sign, bit width, other type info for simple cases where all the info needed is
//  within the instruction or can be read from the FineGrainedStackTable in the SMPFunction.
// NOTE: Must be called after SSA analysis is complete.
void SMPInstr::MDSetWidthSignInfo(bool UseFP) {
	set<DefOrUse, LessDefUse>::iterator UseIter;
	set<DefOrUse, LessDefUse>::iterator DefIter;
	op_t UseOp, DefOp;
	struct FineGrainedInfo FGEntry;
	bool ValueWillChange;
	unsigned short SignMask, TempSign, WidthMask;
	int DefHashValue, UseHashValue;

	case1 = this->IsLoadFromStack();
	case2 = this->MDIsSignedLoad(SignMask);
	case3 = (7 == this->OptType);
	case4 = ((CALL == this->GetDataFlowType()) || (INDIR_CALL == this->GetDataFlowType()));

	// Case 1: Load from stack location.
		bool success = false;
		for (UseIter = this->GetFirstUse(); UseIter != this->GetLastUse(); ++UseIter) {
			UseOp = UseIter->GetOp();
			if (MDIsStackAccessOpnd(UseOp, UseFP)) {
				// Found the stack location being loaded into a register. Now we need
				//  to get the sign and width info from the fine grained stack frame
				//  analysis.
				success = this->GetBlock()->GetFunc()->MDGetFGStackLocInfo(this->address, UseOp, FGEntry);
				assert(success);
				// Now we have signedness info in FGEntry. We need to OR it into the register target of the load.
				if (FGEntry.SignMiscInfo == 0) 
					break; // nothing to OR in; save time
				for (DefIter = this->GetFirstDef(); DefIter != this->GetLastDef(); ++DefIter) {
					DefOp = DefIter->GetOp();
					if (o_reg == DefOp.type) {
						DefOp.reg = MDCanonicalizeSubReg(DefOp.reg);
						TempSign = FGEntry.SignMiscInfo & FG_MASK_SIGNEDNESS_BITS; // Get both sign bit flags
						DefHashValue = HashGlobalNameAndSSA(DefOp, DefIter->GetSSANum());
						if (this->BasicBlock->IsLocalName(DefOp)) {
							this->BasicBlock->UpdateDefSignMiscInfo(DefHashValue, TempSign);
						}
						else {
							this->BasicBlock->GetFunc()->UpdateDefSignMiscInfo(DefHashValue, TempSign);
						}
						break;  // Should be only one register target for stack load, and no flags are set.
					}
				}
				break; // Only concerned with the stack operand
			}
		}
		assert(success);
	} // end if this->IsLoadFromStack()

	// Case 2: Loads that are sign-extended or zero-extended imply signed and unsigned, respectively.
	//  NOTE: If from the stack, they were handled in Case 1, and the signedness of the stack location
	//  was recorded a long time ago in SMPFunction::FindOutgoingArgsSize();
		DefIter = this->GetFirstDef();
		while (DefIter != this->GetLastDef()) {
			// All non-memory DEFs besides the flags register should get the new SignMask ORed in.
			// On x86, there should only be one DEF for this move, and no flags, but we will generalize
			//  in case other architectures are odd.
			DefOp = DefIter->GetOp();
			if (!(IsMemOperand(DefOp) || MDIsFlagsReg(DefOp))) {
				DefOp.reg = MDCanonicalizeSubReg(DefOp.reg);
				DefHashValue = HashGlobalNameAndSSA(DefOp, DefIter->GetSSANum());
				if (this->BasicBlock->IsLocalName(DefOp)) {
					this->BasicBlock->UpdateDefSignMiscInfo(DefHashValue, SignMask);
				}
				else {
					this->BasicBlock->GetFunc()->UpdateDefSignMiscInfo(DefHashValue, SignMask);
				}
			}
			++DefIter;
		}

		// If the signed load is from memory, the only USEs are the memory
		//  operand and addressing registers. We do not want to claim that
		//  EBX is signed in the instruction movsx eax,[ebx]. Only the DEF
		//  register EAX and the memory location [EBX] are signed, and we
		//  have no idea where [EBX] is, so we punt on all USEs if we have
		//  a memory source operand.
		if (!(this->HasSourceMemoryOperand())) {
			UseIter = this->GetFirstUse();
			while (UseIter != this->GetLastUse()) {
				// All non-memory USEs besides the flags register should get the new SignMask ORed in.
				UseOp = UseIter->GetOp();
				if (!(IsMemOperand(UseOp) || MDIsFlagsReg(UseOp))) {
					UseOp.reg = MDCanonicalizeSubReg(UseOp.reg);
					UseHashValue = HashGlobalNameAndSSA(UseOp, UseIter->GetSSANum());
					if (this->BasicBlock->IsLocalName(UseOp)) {
						this->BasicBlock->UpdateUseSignMiscInfo(UseHashValue, SignMask);
					}
					else {
						this->BasicBlock->GetFunc()->UpdateUseSignMiscInfo(UseHashValue, SignMask);
					}
	// Case 3: multiplies and divides can be signed or unsigned.
	else if (case3) { // Multiplies and divides are type 7.
		if (this->MDIsSignedArithmetic()) {
			SignMask = FG_MASK_SIGNED;
		}
		else if (this->MDIsUnsignedArithmetic()) {
		else {
			SignMask = 0; // unknown, uninitialized
		}
		if (0 != SignMask) {
			DefIter = this->GetFirstDef();
			while (DefIter != this->GetLastDef()) {
				// All DEFs besides the flags register should get the new SignMask ORed in.
				DefOp = DefIter->GetOp();
				if ((DefOp.type == o_reg) && (!(DefOp.is_reg(X86_FLAGS_REG)))) {
					DefOp.reg = MDCanonicalizeSubReg(DefOp.reg);
					DefHashValue = HashGlobalNameAndSSA(DefOp, DefIter->GetSSANum());
					if (this->BasicBlock->IsLocalName(DefOp)) {
						this->BasicBlock->UpdateDefSignMiscInfo(DefHashValue, SignMask);
					}
					else {
						this->BasicBlock->GetFunc()->UpdateDefSignMiscInfo(DefHashValue, SignMask);
					}
			UseIter = this->GetFirstUse();
			while (UseIter != this->GetLastUse()) {
				// All USEs besides the flags register should get the new SignMask ORed in.
				UseOp = UseIter->GetOp();
				if ((UseOp.type == o_reg) && (!(UseOp.is_reg(X86_FLAGS_REG)))) {
					UseOp.reg = MDCanonicalizeSubReg(UseOp.reg);
					UseHashValue = HashGlobalNameAndSSA(UseOp, UseIter->GetSSANum());
					if (this->BasicBlock->IsLocalName(UseOp)) {
						this->BasicBlock->UpdateUseSignMiscInfo(UseHashValue, SignMask);
					}
					else {
						this->BasicBlock->GetFunc()->UpdateUseSignMiscInfo(UseHashValue, SignMask);
					}
		} // end if (0 != SignMask)
	} // end of case 3 (multiplies and divides)

	// Case 4: Calls to library functions can reveal the type of the return register.
	else if (case4) {
		// Get name of function called.
		string FuncName = this->GetTrimmedCalledFunctionName();

		// Get FG info, if any, for called function.
		GetLibFuncFGInfo(FuncName, FGEntry);

		// See if anything was returned in FGEntry.
		if ((FGEntry.SignMiscInfo != 0) || (FGEntry.SizeInfo != 0)) {
			// Need to update the FG info for the DEF of the return register.
			DefOp = InitOp;
			DefOp.type = o_reg;
			DefOp.reg = MD_RETURN_VALUE_REG;
			DefIter = this->FindDef(DefOp);
			assert(DefIter != this->GetLastDef());
			DefHashValue = HashGlobalNameAndSSA(DefOp, DefIter->GetSSANum());
			if (this->BasicBlock->IsLocalName(DefOp)) {
				this->BasicBlock->UpdateDefFGInfo(DefHashValue, FGEntry);
			}
			else {
				this->BasicBlock->GetFunc()->UpdateDefFGInfo(DefHashValue, FGEntry);
			}
		}
	} // end of case4 (function calls)
	// For all register DEFs and USEs, we should get the obvious register width info
	//  updated. Need to use the RTL operands to get accurate widths.
	SMPRegTransfer *CurrRT;
	for (size_t index = 0; index < this->RTL.GetCount(); ++index) {
		CurrRT = this->RTL.GetRT(index);
		DefOp = CurrRT->GetLeftOperand();
		// Avoid setting def width for case 2; we leave it as zero so that
		//  later uses can determine whether the zero-extension or sign-extension
		//  bits ever got used. See more discussion in EmitIntegerErrorAnnotations()
		//  for the CHECK TRUNCATION case.
		// NOTE: case2 can be set to true even in the case1/case2 overlap case that
		//  only passes through the case1 code above. This is intentional. We want
		//  to leave the DEF width set to 0 for all of case2 including the case1 overlap.
		if (!case2 && MDIsGeneralPurposeReg(DefOp)) {
			WidthMask = ComputeOperandBitWidthMask(DefOp, 0);
			DefOp.reg = MDCanonicalizeSubReg(DefOp.reg);
			DefIter = this->FindDef(DefOp);
			assert(DefIter != this->GetLastDef());
			DefHashValue = HashGlobalNameAndSSA(DefOp, DefIter->GetSSANum());
			if (this->BasicBlock->IsLocalName(DefOp)) {
				this->BasicBlock->UpdateDefWidthTypeInfo(DefHashValue, WidthMask);
			}
			else {
				this->BasicBlock->GetFunc()->UpdateDefWidthTypeInfo(DefHashValue, WidthMask);
			}
		}
		if (CurrRT->HasRightSubTree()) {
			this->MDSetRTLRegWidthInfo(CurrRT->GetRightTree());
		}
		else {
			UseOp = CurrRT->GetRightOperand();
			this->SetRTLUseOpRegWidthInfo(UseOp);
		}
	}  // end for all RTLs 

	return;
} // end of SMPInstr::MDSetWidthSignInfo()

// Infer sign from the SMP types for USEs and DEFs.
void SMPInstr::InferSignednessFromSMPTypes(bool UseFP) {
	// Start with registers only, infer that all kids of pointers are UNSIGNED.
	set<DefOrUse, LessDefUse>::iterator DefIter, UseIter;
	op_t DefOp, UseOp;
	int SSANum;
	int DefHashValue, UseHashValue;
	SMPOperandType DefType, UseType;
	unsigned short DefSignMiscInfo = FG_MASK_UNSIGNED, UseSignMiscInfo = FG_MASK_UNSIGNED;
	bool GlobalName;

	for (DefIter = this->GetFirstDef(); DefIter != this->GetLastDef(); ++DefIter) {
		DefOp = DefIter->GetOp();
		if (MDIsGeneralPurposeReg(DefOp)) {
			DefType = DefIter->GetType();
			if (IsDataPtr(DefType) || (CODEPTR == DefType)) {
				GlobalName = this->BasicBlock->GetFunc()->IsGlobalName(DefOp);
				SSANum = DefIter->GetSSANum();
				DefHashValue = HashGlobalNameAndSSA(DefOp, SSANum);
				if (GlobalName) {
					this->BasicBlock->GetFunc()->UpdateDefSignMiscInfo(DefHashValue, DefSignMiscInfo);
				}
				else {
					this->BasicBlock->UpdateDefSignMiscInfo(DefHashValue, DefSignMiscInfo);
				}
			}
		}
	}

	for (UseIter = this->GetFirstUse(); UseIter != this->GetLastUse(); ++UseIter) {
		UseOp = UseIter->GetOp();
		if (MDIsGeneralPurposeReg(UseOp)) {
			UseType = UseIter->GetType();
			if (IsDataPtr(UseType) || (CODEPTR == UseType)) {
				GlobalName = this->BasicBlock->GetFunc()->IsGlobalName(UseOp);
				SSANum = UseIter->GetSSANum();
				UseHashValue = HashGlobalNameAndSSA(UseOp, SSANum);
				if (GlobalName) {
					this->BasicBlock->GetFunc()->UpdateUseSignMiscInfo(UseHashValue, UseSignMiscInfo);
				}
				else {
					this->BasicBlock->UpdateUseSignMiscInfo(UseHashValue, UseSignMiscInfo);
				}
			}
		}
	}

	return;
} // end of SMPInstr::InferSignednessFromSMPTypes()



// Helper: record the bit width mask for one RTL USE operand.
//  Operands that are not general purpose registers are ignored. For the rest,
//  the width mask is computed from the operand as passed in, the register is
//  then canonicalized, and the mask is recorded for the matching USE (through
//  the basic block for local names, through the function otherwise).
//  The canonicalized register must be present in the USE list (asserted).
void SMPInstr::SetRTLUseOpRegWidthInfo(op_t UseOp) {
	if (!MDIsGeneralPurposeReg(UseOp)) {
		return; // nothing to record for non-register operands
	}

	// Width must be taken before the subreg is canonicalized.
	unsigned short BitWidthMask = ComputeOperandBitWidthMask(UseOp, 0);
	UseOp.reg = MDCanonicalizeSubReg(UseOp.reg);

	set<DefOrUse, LessDefUse>::iterator FoundUse = this->FindUse(UseOp);
	assert(FoundUse != this->GetLastUse());

	unsigned int HashValue = HashGlobalNameAndSSA(UseOp, FoundUse->GetSSANum());
	if (this->BasicBlock->IsLocalName(UseOp)) {
		this->BasicBlock->UpdateUseWidthTypeInfo(HashValue, BitWidthMask);
	}
	else {
		this->BasicBlock->GetFunc()->UpdateUseWidthTypeInfo(HashValue, BitWidthMask);
	}

	return;
} // end of SMPInstr::SetRTLUseOpRegWidthInfo()

// Walk the RTL and update the register USE operands' width info.
//  Starting at CurrRT, the left operand of each RT is recorded, then the walk
//  descends into the right subtree while one exists; the right operand of the
//  last RT reached is recorded at the end.
void SMPInstr::MDSetRTLRegWidthInfo(SMPRegTransfer *CurrRT) {
	SMPRegTransfer *Node = CurrRT;

	// Follow the chain of right subtrees, handling each left operand on the way down.
	for (;;) {
		op_t LeftUseOp = Node->GetLeftOperand();
		this->SetRTLUseOpRegWidthInfo(LeftUseOp);
		if (!(Node->HasRightSubTree())) {
			break;
		}
		Node = Node->GetRightTree();
	}

	// The deepest RT has a plain right operand instead of a subtree.
	op_t RightUseOp = Node->GetRightOperand();
	this->SetRTLUseOpRegWidthInfo(RightUseOp);

	return;
} // end of SMPInstr::MDSetRTLRegWidthInfo()

// Infer DEF, USE, and RTL SMPoperator types within the instruction based on the type
//  of operator, the type category of the instruction, and the previously known types
//  of the operands.
// Returns true after handling a control flow instruction (data flow type between JUMP
//  and INDIR_CALL); otherwise returns the accumulated "changed" flag.
// NOTE(review): the text of this function as seen here appears to be missing lines
//  (unbalanced braces, an #if without #endif, truncated calls, and uses of SSANum and
//  IsMemOp without visible declarations). Compare against the repository copy before
//  editing the logic.
bool SMPInstr::InferTypes(void) {
	bool changed = false;  // return value
	int TypeCategory = SMPTypeCategory[this->SMPcmd.itype];
	set<DefOrUse, LessDefUse>::iterator CurrDef;
	set<DefOrUse, LessDefUse>::iterator CurrUse;
	op_t DefOp = InitOp, UseOp = InitOp;
	bool DebugFlag = false;
	bool UseFP = this->BasicBlock->GetFunc()->UsesFramePointer();
	bool SafeFunc = this->BasicBlock->GetFunc()->IsSafe();
	// NOTE(review): no matching #endif is visible for the #if below -- TODO confirm.
#if SMP_VERBOSE_DEBUG_INFER_TYPES
	DebugFlag |= (0 == strcmp("InputMove", this->BasicBlock->GetFunc()->GetFuncName()));
	if (DebugFlag) {
		msg("opcode: %d TypeCategory: %d\n", this->SMPcmd.itype, TypeCategory);
	}

	// If we are already finished with all types, return false.
	// NOTE(review): the return statements and closing brace described by the comment
	//  above are not visible in the next three lines -- TODO confirm.
	if (this->IsTypeInferenceComplete())
	if (this->AllDEFsTyped() && this->AllUSEsTyped()) {
		this->SetTypeInferenceComplete();
	// Look for pointer uses within the memory operands, destination first, then source.
	if (this->HasDestMemoryOperand()) {
		changed |= this->MDFindPointerUse(this->MDGetMemDefOp(), UseFP);
	}
	if (this->HasSourceMemoryOperand()) {
		changed |= this->MDFindPointerUse(this->MDGetMemUseOp(), UseFP);
	}

	// The control flow instructions can be handled simply based on their type
	//  and do not need an RTL walk.
	SMPitype DFAType = this->GetDataFlowType();
	bool CallInst = ((DFAType == CALL) || (DFAType == INDIR_CALL));
	ushort IndirCallReg = R_none;
	// NOTE(review): the debug block below has no visible closing brace -- TODO confirm.
	if (DebugFlag) {
		msg("DFAType: %d  CategoryInferenceComplete: %d\n",
			DFAType, this->IsCategoryInferenceComplete());
	// Remember the target register of an indirect call through a register.
	if (DFAType == INDIR_CALL) {
		op_t TargetOp = this->SMPcmd.Operands[0];
		if (TargetOp.type == o_reg)
			IndirCallReg = TargetOp.reg;
	}
	if ((DFAType >= JUMP) && (DFAType <= INDIR_CALL)) {
		// All USEs are either the flags (NUMERIC) or the target address (CODEPTR).
		//  The exceptions are the USE list for interrupt calls, which includes
		//  the caller-saved regs, and indirect calls through a memory
		//  operand, such as call [ebx+esi+20h], where the memory operand
		//  is a CODEPTR but the addressing registers are a BaseReg and
		//  IndexReg as in any other memory addressing, and the caller-saved
		//  regs on any call.
		CurrUse = this->GetFirstUse();
		while (CurrUse != this->GetLastUse()) {
			UseOp = CurrUse->GetOp();
			if (UseOp.is_reg(X86_FLAGS_REG))
				CurrUse = this->SetUseType(UseOp, NUMERIC);
			else if ((CurrUse->GetType() != CODEPTR)
				&& (!(this->MDIsInterruptCall() && (o_reg == UseOp.type)))
				&& (!(CallInst && MDIsCallerSavedReg(UseOp)))
				&& (!(this->HasSourceMemoryOperand() 
					&& (INDIR_CALL == this->GetDataFlowType())
					&& (o_reg == UseOp.type)))) {
				CurrUse = this->SetUseType(UseOp, CODEPTR);
				if (DFAType == CALL) {
					// If the call is to malloc(), then the DEF of the return
					//  register is of type HEAPPTR.
					changed |= this->MDFindMallocCall(UseOp);
				}
			// NOTE(review): the brace closing the previous branch, the iterator
			//  increment, and the end of the while loop are not visible here -- TODO confirm.
			else if ((CurrUse->GetType() != CODEPTR) && CallInst
				&& UseOp.is_reg(IndirCallReg)) {

				CurrUse = this->SetUseType(UseOp, CODEPTR);
			}
		this->SetTypeInferenceComplete();
		return true;
	}

	// First, see if we can infer something about DEFs and USEs just from the 
	//  type category of the instruction.
	if (!this->IsCategoryInferenceComplete()) {
		bool MemPropagate = false;
		switch (TypeCategory) {
			case 0: // no inference possible just from type category
			case 1: // no inference possible just from type category
			case 3:  // MOV instructions; inference will come from source to dest in RTL walk.
			case 5:  // binary arithmetic; inference will come in RTL walk.
			case 10:  // binary arithmetic; inference will come in RTL walk.
			case 11:  // push and pop instructions; inference will come in RTL walk.
			case 12:  // exchange instructions; inference will come in RTL walk.
				this->SetCategoryInferenceComplete();
				break;

			case 2: // Result type is always NUMERIC.
			case 7: // Result type is always NUMERIC.
			case 8: // Result type is always NUMERIC.
			case 9: // Result type is always NUMERIC.
			case 13: // Result type is always NUMERIC.
			case 14: // Result type is always NUMERIC.
			case 15: // Result type is always NUMERIC.
				CurrDef = this->GetFirstDef();
				while (CurrDef != this->GetLastDef()) {
					if (!IsEqType(NUMERIC, CurrDef->GetType())) {
						DefOp = CurrDef->GetOp();
						// NOTE(review): no declaration of SSANum is visible in this function -- TODO confirm.
						SSANum = CurrDef->GetSSANum();
						CurrDef = this->SetDefType(DefOp, NUMERIC);
						changed = true;
						// Be conservative and only propagate register DEFs and SAFE stack locs. We
						//  can improve this in the future. **!!**
						bool IsMemOp = (o_reg != DefOp.type);
						// NOTE(review): this declaration shadows the MemPropagate declared
						//  just before the switch.
						bool MemPropagate = MDIsStackAccessOpnd(DefOp, UseFP);
#if SMP_PROPAGATE_MEM_TYPES
						;
#else
						// Be conservative and only propagate register DEFs and SAFE stack locs.
						//  We can improve this in the future. **!!**
						MemPropagate = MemPropagate && SafeFunc;
#endif
						if ((o_reg == DefOp.type) || MemPropagate) {
							if (this->BasicBlock->IsLocalName(DefOp)) {
								(void) this->BasicBlock->PropagateLocalDefType(DefOp, NUMERIC,
									this->GetAddr(), SSANum, IsMemOp);
							}
							else { // global name
								this->BasicBlock->GetFunc()->ResetProcessedBlocks(); // set Processed to false
								// NOTE(review): the remaining arguments of the call below, the
								//  closing braces of the loop, and this case's break are not
								//  visible -- TODO confirm.
								(void) this->BasicBlock->PropagateGlobalDefType(DefOp, NUMERIC,
				this->SetCategoryInferenceComplete();
			case 4: // Unary INC, DEC, etc.: dest=source, so type remains the same
				assert(1 == this->RTL.GetCount());
				assert(this->RTL.GetRT(0)->HasRightSubTree());
				UseOp = this->RTL.GetRT(0)->GetLeftOperand(); // USE == DEF
				CurrUse = this->Uses.FindRef(UseOp);
				assert(CurrUse != this->GetLastUse());
				if (UNINIT != CurrUse->GetType()) {
					// Only one USE, and it has a type assigned, so assign that type
					// to the DEF.
					CurrDef = this->GetFirstDef();
					while (CurrDef != this->GetLastDef()) {
						// Two DEFs: EFLAGS is NUMERIC, dest==source
						DefOp = CurrDef->GetOp();
						SSANum = CurrDef->GetSSANum();
						if (DefOp.is_reg(X86_FLAGS_REG)) {
							; // SetImmedTypes already made it NUMERIC
							// NOTE(review): an else branch separating the flags case from the
							//  code below appears to be missing -- TODO confirm.
							CurrDef = this->SetDefType(DefOp, CurrUse->GetType());
							// Be conservative and only propagate register DEFs and SAFE stack locs. We
							//  can improve this in the future. **!!**
							bool IsMemOp = (o_reg != DefOp.type);
							MemPropagate = MDIsStackAccessOpnd(DefOp, UseFP);
#if SMP_PROPAGATE_MEM_TYPES
							;
#else
							// Be conservative and only propagate register DEFs and SAFE stack locs.
							//  We can improve this in the future. **!!**
							MemPropagate = MemPropagate && SafeFunc;
#endif
							if ((o_reg == DefOp.type) || MemPropagate) {
								if (this->BasicBlock->IsLocalName(DefOp)) {
									(void) this->BasicBlock->PropagateLocalDefType(DefOp, CurrUse->GetType(),
										this->GetAddr(), SSANum, IsMemOp);
								}
								else { // global name
									this->BasicBlock->GetFunc()->ResetProcessedBlocks(); // set Processed to false
									// NOTE(review): call below is truncated; closing braces and the
									//  loop increment are not visible -- TODO confirm.
									(void) this->BasicBlock->PropagateGlobalDefType(DefOp, CurrUse->GetType(),
					this->SetCategoryInferenceComplete();
					changed = true;
					this->SetTypeInferenceComplete();
				}
				break;

			case 6: // Result is always POINTER
				DefOp = this->GetFirstDef()->GetOp();
				SSANum = this->GetFirstDef()->GetSSANum();
				CurrDef = this->SetDefType(DefOp, POINTER);
				this->SetCategoryInferenceComplete();
				changed = true;
				// Be conservative and only propagate register DEFs and SAFE stack locs. We
				//  can improve this in the future. **!!**
				// NOTE(review): no declaration of IsMemOp is visible in this scope -- TODO confirm.
				IsMemOp = (o_reg != DefOp.type);
				MemPropagate = MDIsStackAccessOpnd(DefOp, UseFP);
#if SMP_PROPAGATE_MEM_TYPES
				;
#else
				// Be conservative and only propagate register DEFs and SAFE stack locs.
				//  We can improve this in the future. **!!**
				MemPropagate = MemPropagate && SafeFunc;
#endif
				if ((o_reg == DefOp.type) || MemPropagate)  {
					if (this->BasicBlock->IsLocalName(DefOp)) {
						(void) this->BasicBlock->PropagateLocalDefType(DefOp, POINTER,
							this->GetAddr(), SSANum, IsMemOp);
					}
					else { // global name
						this->BasicBlock->GetFunc()->ResetProcessedBlocks(); // set Processed to false
						// NOTE(review): call below is truncated and its enclosing braces are
						//  not closed in the visible text -- TODO confirm.
						(void) this->BasicBlock->PropagateGlobalDefType(DefOp, POINTER,
				break;

			default:
				msg("ERROR: Unknown type category for %s\n", this->GetDisasm());
				this->SetCategoryInferenceComplete();
				break;
		} // end switch on TypeCategory
	} // end if (!CategoryInference)

	// Walk the RTL and infer types based on operators and operands.
	if (DebugFlag) {
		msg("RTcount: %d\n", this->RTL.GetCount());
	}
	for (size_t index = 0; index < this->RTL.GetCount(); ++index) {
		SMPRegTransfer *CurrRT = this->RTL.GetRT(index);
		if (SMP_NULL_OPERATOR == CurrRT->GetOperator()) // nothing to infer
			continue;
		// NOTE(review): the next two lines are not C++; they look like residue from a
		//  web page export -- TODO remove after confirming against the repository.
clc5q's avatar
clc5q committed
		// Only RTs whose type inference is not yet complete are processed.
		if (!(CurrRT->IsTypeInferenceComplete())) {
			changed |= this->InferOperatorType(CurrRT);
		}
		if (DebugFlag) {
			msg("returned from InferOperatorType\n");
		}
	} // end for all RTs in the RTL
	return changed;
} // end of SMPInstr::InferTypes()

// Infer the type of an operator within an RT based on the types of its operands and
//  based on the operator itself. Recurse down the tree if necessary.
// Return true if the operator type of the RT is updated.
bool SMPInstr::InferOperatorType(SMPRegTransfer *CurrRT) {
	bool updated = false;
	bool LeftNumeric, RightNumeric;
	bool LeftPointer, RightPointer;
	bool UseFP = this->BasicBlock->GetFunc()->UsesFramePointer();
	bool SafeFunc = this->BasicBlock->GetFunc()->IsSafe();
	set<DefOrUse, LessDefUse>::iterator CurrDef;
	set<DefOrUse, LessDefUse>::iterator CurrUse;
	set<DefOrUse, LessDefUse>::iterator LeftUse;
	set<DefOrUse, LessDefUse>::iterator RightUse;
	SMPOperandType LeftType = UNINIT;
	SMPOperandType RightType = UNINIT;
	SMPOperandType OperType = UNINIT;
	op_t UseOp = InitOp, DefOp = InitOp, LeftOp = InitOp, RightOp = InitOp;
	SMPoperator CurrOp = CurrRT->GetOperator();
	bool DebugFlag = false;
clc5q's avatar
clc5q committed
	bool TypeInferenceFinished = false;
#if SMP_VERBOSE_DEBUG_INFER_TYPES
#if 1
	DebugFlag |= (0 == strcmp("InputMove", this->BasicBlock->GetFunc()->GetFuncName()));
	DebugFlag = DebugFlag || ((this->address == 0x806453b) || (this->address == 0x806453e));
#if SMP_VERBOSE_DEBUG_INFER_TYPES
	if (DebugFlag) {
		msg("Entered InferOperatorType for CurrOp: %d\n", CurrOp);
	}
clc5q's avatar
clc5q committed

	if (CurrRT->IsTypeInferenceComplete()) {
		return updated;
	}

	switch (CurrOp) {
		case SMP_NULL_OPERATOR:
clc5q's avatar
clc5q committed
			TypeInferenceFinished = true;
			break;

		case SMP_CALL:  // CALL instruction
			if (UNINIT == CurrRT->GetOperatorType()) {
				CurrRT->SetOperatorType(CODEPTR, this);
				updated = true;
				UseOp = CurrRT->GetRightOperand();
				CurrUse = this->Uses.FindRef(UseOp);
				assert(CurrUse != this->GetLastUse());
				if (UNINIT == CurrUse->GetType()) {
					CurrUse = this->SetUseType(UseOp, CODEPTR);
				}
				else if (CODEPTR != CurrUse->GetType()) {
					msg("WARNING: call target is type %d, setting to CODEPTR at %x in %s\n",
						CurrUse->GetType(), this->GetAddr(), this->GetDisasm());
					CurrUse = this->SetUseType(UseOp, CODEPTR);
				}
clc5q's avatar
clc5q committed
			TypeInferenceFinished = true;
			break;

		case SMP_INPUT:  // input from port
			if (UNINIT == CurrRT->GetOperatorType()) {
				CurrRT->SetOperatorType(NUMERIC, this);
				updated = true;
			}
			break;

		case SMP_OUTPUT: // output to port
			if (UNINIT == CurrRT->GetOperatorType()) {
				CurrRT->SetOperatorType(NUMERIC, this);
				updated = true;
			}
			break;

		case SMP_SIGN_EXTEND:
		case SMP_ZERO_EXTEND:
clc5q's avatar
clc5q committed
			// Should we infer that all operands are NUMERIC?  !!!???!!!!
			break;

		case SMP_ADDRESS_OF: // take effective address
			if (UNINIT == CurrRT->GetOperatorType()) {
				CurrRT->SetOperatorType(POINTER, this);
				// Left operand is having its address taken, but we cannot infer what its
				//  type is.
				updated = true;
			}
			break;

		case SMP_U_LEFT_SHIFT: // unsigned left shift
		case SMP_S_LEFT_SHIFT: // signed left shift
		case SMP_U_RIGHT_SHIFT: // unsigned right shift
		case SMP_S_RIGHT_SHIFT: // signed right shift
		case SMP_ROTATE_LEFT:
		case SMP_ROTATE_LEFT_CARRY: // rotate left through carry
		case SMP_ROTATE_RIGHT:
		case SMP_ROTATE_RIGHT_CARRY: // rotate right through carry
		case SMP_U_MULTIPLY:
		case SMP_S_MULTIPLY:
		case SMP_U_DIVIDE:
		case SMP_S_DIVIDE:
		case SMP_U_REMAINDER:
		case SMP_BITWISE_NOT: // unary operator
		case SMP_BITWISE_XOR:
		case SMP_NEGATE:    // unary negation
		case SMP_S_COMPARE: // signed compare (subtraction-based)
		case SMP_U_COMPARE: // unsigned compare (AND-based)
		case SMP_LESS_THAN: // boolean test operators
		case SMP_GREATER_THAN:
		case SMP_LESS_EQUAL:
		case SMP_GREATER_EQUAL:
		case SMP_EQUAL:
		case SMP_NOT_EQUAL:
		case SMP_LOGICAL_AND:
		case SMP_LOGICAL_OR:
		case SMP_UNARY_NUMERIC_OPERATION:  // miscellaneous; produces NUMERIC result
		case SMP_BINARY_NUMERIC_OPERATION:  // miscellaneous; produces NUMERIC result
		case SMP_SYSTEM_OPERATION:   // for instructions such as CPUID, RDTSC, etc.; NUMERIC
		case SMP_UNARY_FLOATING_ARITHMETIC:  // all the same to our type system; all NUMERIC
		case SMP_BINARY_FLOATING_ARITHMETIC:  // all the same to our type system; all NUMERIC
		case SMP_REVERSE_SHIFT_U:   // all the same to our type system; all NUMERIC
		case SMP_SHUFFLE:   // all the same to our type system; all NUMERIC
		case SMP_COMPARE_EQ_AND_SET:   // packed compare for equality and set bits; all NUMERIC
		case SMP_COMPARE_GT_AND_SET:   // packed compare for greater-than and set bits; all NUMERIC
		case SMP_INTERLEAVE:  // interleave fields from two packed operands; NUMERIC
		case SMP_CONCATENATE:   // all the same to our type system; all NUMERIC
			if (UNINIT == CurrRT->GetOperatorType()) {
				CurrRT->SetOperatorType(NUMERIC, this);
				updated = true;
			}
			// Left operand should be NUMERIC if it exists.
			UseOp = CurrRT->GetLeftOperand();
			if (UseOp.type != o_void) {
				CurrUse = this->Uses.FindRef(UseOp);
				if (CurrUse == this->GetLastUse()) {
					msg("WARNING: Adding missing USE of ");
					PrintOperand(UseOp);
					msg(" in %s\n", this->GetDisasm());
					this->Uses.SetRef(UseOp, NUMERIC, -1);
					updated = true;
				}
				else if (UNINIT == CurrUse->GetType()) {
					CurrUse = this->SetUseType(UseOp, NUMERIC);
					updated = true;
				}
			}
			// Right operand should be NUMERIC if it exists.
			if (CurrRT->HasRightSubTree()) {
				// Recurse into subtree
				updated |= this->InferOperatorType(CurrRT->GetRightTree());
			}
			else {
				UseOp = CurrRT->GetRightOperand();
				if (UseOp.type != o_void) {
					CurrUse = this->Uses.FindRef(UseOp);
					if (CurrUse == this->GetLastUse()) {
						msg("WARNING: Adding missing USE of ");
						PrintOperand(UseOp);
						msg(" in %s\n", this->GetDisasm());
						this->Uses.SetRef(UseOp, NUMERIC, -1);
						updated = true;
					}
					else if (UNINIT == CurrUse->GetType()) {
						CurrUse = this->SetUseType(UseOp, NUMERIC);
						updated = true;
					}
				}
			}
			break;

		case SMP_INCREMENT:
		case SMP_DECREMENT:
			// The type of the right operand is propagated to the operator, or vice
			//  versa, whichever receives a type first.
			assert(!CurrRT->HasRightSubTree());
			UseOp = CurrRT->GetLeftOperand();
			assert(o_void != UseOp.type);
			CurrUse = this->Uses.FindRef(UseOp);
			if (CurrUse == this->GetLastUse()) {
				msg("WARNING: Adding missing USE of ");
				PrintOperand(UseOp);
				msg(" at %x in %s\n", this->GetAddr(), this->GetDisasm());
				this->Uses.SetRef(UseOp);
				updated = true;
				break;
			}
			if (UNINIT == CurrRT->GetOperatorType()) {
				if (UNINIT != CurrUse->GetType()) {
					// Propagate operand type up to the operator.
					CurrRT->SetOperatorType(CurrUse->GetType(), this);
					updated = true;
				}
			}
			else if (UNINIT == CurrUse->GetType()) {
				// Propagate operator type to operand.
				CurrUse = this->SetUseType(UseOp, CurrRT->GetOperatorType());
				updated = true;
			}
			break;

		case SMP_ADD:
		case SMP_ADD_CARRY:   // add with carry
		case SMP_BITWISE_AND:
		case SMP_BITWISE_OR:
clc5q's avatar
clc5q committed
			// Extract the current types of right and left operands and the operator.
			LeftOp = CurrRT->GetLeftOperand();
			CurrUse = this->Uses.FindRef(LeftOp);
			assert(CurrUse != this->GetLastUse()); // found it
			LeftType = CurrUse->GetType();
			if (CurrRT->HasRightSubTree()) {
				RightType = CurrRT->GetRightTree()->GetOperatorType();
			}
			else {
				RightOp = CurrRT->GetRightOperand();
				if (o_void == RightOp.type) {
					msg("ERROR: void operand in %s\n", this->GetDisasm());
					return false;
clc5q's avatar
clc5q committed
					CurrUse = this->Uses.FindRef(RightOp);
					if (CurrUse == this->GetLastUse()) {
						msg("WARNING: Adding missing USE of ");
clc5q's avatar
clc5q committed
						PrintOperand(RightOp);
						msg(" in %s\n", this->GetDisasm());
clc5q's avatar
clc5q committed
						this->Uses.SetRef(RightOp);
						RightType = CurrUse->GetType();
clc5q's avatar
clc5q committed
				}
			}

			// We have to know both operand types to infer the operator, or know the
			//  operator type to infer the operand types.
			if ((UNINIT == CurrRT->GetOperatorType()) 
				&& ((UNINIT == LeftType) || (UNINIT == RightType)))
				break;

			// If both operands are NUMERIC, operator and result are NUMERIC.
			// If one operand is NUMERIC and the other is a pointer type,
			//  then the ADD operator and the result will inherit this second type,
			//  while AND and OR operators will remain UNINIT (we don't know what
			//  type "ptr AND 0xfffffff8" has until we see how it is used).
			LeftNumeric = IsEqType(NUMERIC, LeftType);
			RightNumeric = IsEqType(NUMERIC, RightType);
			LeftPointer = IsDataPtr(LeftType);
			RightPointer = IsDataPtr(RightType);
clc5q's avatar
clc5q committed
			if (UNINIT == CurrRT->GetOperatorType()) {
				// Infer operator type from left and right operands.
				if (LeftNumeric && RightNumeric) {
					CurrRT->SetOperatorType(NUMERIC, this);
					updated = true;
clc5q's avatar
clc5q committed
					break;
				}
				else if (LeftNumeric || RightNumeric) {
					// ADD of NUMERIC to non-NUMERIC preserves non-NUMERIC type.
					// AND and OR operations should leave the operator UNINIT for now.
					if (LeftNumeric && (UNINIT != RightType) 
						&& ((SMP_ADD == CurrOp) || (SMP_ADD_CARRY == CurrOp))) {
						CurrRT->SetOperatorType(RightType, this);
						updated = true;
clc5q's avatar
clc5q committed
						break;
					else if (RightNumeric && (UNINIT != LeftType) 
						&& ((SMP_ADD == CurrOp) || (SMP_ADD_CARRY == CurrOp))) {
						CurrRT->SetOperatorType(LeftType, this);
						updated = true;
clc5q's avatar
clc5q committed
						break;
				else if (LeftPointer && RightPointer) {
					// Arithmetic on two pointers
					if ((SMP_ADD == CurrOp) || (SMP_ADD_CARRY == CurrOp)) {
						CurrRT->SetOperatorType(UNKNOWN, this);
					}
					else { // bitwise AND or OR of two pointers
						msg("WARNING: hash of two pointers at %x in %s\n",
							this->GetAddr(), this->GetDisasm());
						// hash operation? leave operator as UNINIT
clc5q's avatar
clc5q committed
					break;
				else if ((LeftPointer && IsEqType(RightType, PTROFFSET))
					|| (RightPointer && IsEqType(LeftType, PTROFFSET))) {
					// Arithmetic on PTR and PTROFFSET
					if ((SMP_ADD == CurrOp) || (SMP_ADD_CARRY == CurrOp)) {
						// We assume (A-B) is being added to B or vice versa **!!**
						CurrRT->SetOperatorType(POINTER, this);
					}
					else { // bitwise AND or OR of pointer and pointer difference
						msg("WARNING: hash of PTROFFSET and POINTER at %x in %s\n",
							this->GetAddr(), this->GetDisasm());
						// hash operation? leave operator as UNINIT
					break;
				}
			} // end if UNINIT operator type
			else { // operator has type other than UNINIT
				// We make add-with-carry and subtract-with-borrow exceptions
				//  to the type propagation. LeftOp could have POINTER type
				//  inferred later; these instructions can change the type of
				//  the register from POINTER to NUMERIC, unlike regular
				//  add and subtract opcodes.
				if ((UNINIT == LeftType)
					&& (SMP_ADD_CARRY != CurrOp)) {
					CurrUse = this->SetUseType(LeftOp, CurrRT->GetOperatorType());
					updated = true;
					assert(CurrUse != this->GetLastUse());
					break;
				}
				if (CurrRT->HasRightSubTree()) {
					// We must iterate through the right subtree again, now that the
					//  operator has been typed.
					if (UNINIT == RightType) {
						CurrRT->GetRightTree()->SetOperatorType(CurrRT->GetOperatorType(), this);
						updated = true;
					}
					updated |= this->InferOperatorType(CurrRT->GetRightTree());
					break;
				}
				else { // right operand; propagate operator type if needed
					if (UNINIT == RightType) {
						CurrUse = this->SetUseType(RightOp, CurrRT->GetOperatorType());
						updated = true;
						assert(CurrUse != this->GetLastUse());
						break;
		case SMP_SUBTRACT_BORROW:  // subtract with borrow
			// Extract the current types of right and left operands and the operator.
			OperType = CurrRT->GetOperatorType();
			LeftOp = CurrRT->GetLeftOperand();
			LeftUse = this->Uses.FindRef(LeftOp);
			assert(LeftUse != this->GetLastUse()); // found it
			LeftType = LeftUse->GetType();
			if (CurrRT->HasRightSubTree()) {
				RightType = CurrRT->GetRightTree()->GetOperatorType();
			}
			else {
				RightOp = CurrRT->GetRightOperand();
				if (o_void == RightOp.type) {
					msg("ERROR: void operand in %s\n", this->GetDisasm());
					return false;
				}
				else {
					RightUse = this->Uses.FindRef(RightOp);
					if (RightUse == this->GetLastUse()) {
						msg("WARNING: Adding missing USE of ");
						PrintOperand(RightOp);
						msg(" in %s\n", this->GetDisasm());
						this->Uses.SetRef(RightOp);
						updated = true;
						break;
					}
					else {
						RightType = RightUse->GetType();
					}
				}
			}
			// If left operand is NUMERIC, operator is NUMERIC.
			LeftNumeric = IsEqType(NUMERIC, LeftType);
			RightNumeric = IsEqType(NUMERIC, RightType);
			LeftPointer = IsDataPtr(LeftType);
			RightPointer = IsDataPtr(RightType);
			if (LeftNumeric) {
				// Subtracting anything from a NUMERIC leaves it NUMERIC.
				if (UNINIT == OperType) {
					CurrRT->SetOperatorType(NUMERIC, this);
					updated = true;
				}
				else if (NUMERIC != OperType) {
					msg("ERROR: SMP_SUBTRACT from NUMERIC should be NUMERIC operator.");
					msg(" Operator type is %d in: %s\n", OperType, this->GetDisasm());
				}
				if (!RightNumeric) {
					// Right operand is being used as a NUMERIC, so propagate NUMERIC to it.
					if (CurrRT->HasRightSubTree()) {
						CurrRT->GetRightTree()->SetOperatorType(NUMERIC, this);
					}
					else {
						RightUse = this->SetUseType(RightOp, NUMERIC);
					}
					updated = true;
				}
			} // end if LeftNumeric
			else if (LeftPointer) {
				if (UNINIT == OperType) {
					// If we subtract another pointer type, we produce PTROFFSET.
					if (RightPointer) {
						CurrRT->SetOperatorType(PTROFFSET, this);
						updated = true;
					}
					else if (RightType == PTROFFSET) {
						// We assume B - (B - A) == A    **!!**
						CurrRT->SetOperatorType(POINTER, this);
						msg("WARNING: PTR - PTROFFSET produces PTR in %s\n", this->GetDisasm());
						updated = true;
					}