diff --git a/src/llvm/expr.c b/src/llvm/expr.c index 3dc3cd4..c43c72c 100644 --- a/src/llvm/expr.c +++ b/src/llvm/expr.c @@ -4,6 +4,7 @@ #include #include +#include BackendError impl_bitwise_operation(LLVMBackendCompileUnit *unit, LLVMLocalScope *scope, @@ -13,7 +14,6 @@ BackendError impl_bitwise_operation(LLVMBackendCompileUnit *unit, // TODO: resolve lhs and rhs or op LLVMValueRef rhs = NULL; LLVMValueRef lhs = NULL; - LLVMValueRef op = NULL; if (operation->impl.bitwise == BitwiseNot) { // single operand @@ -65,11 +65,10 @@ BackendError impl_logical_operation(LLVMBackendCompileUnit *unit, // TODO: resolve lhs and rhs or op LLVMValueRef rhs = NULL; LLVMValueRef lhs = NULL; - LLVMValueRef op = NULL; - if (operation->kind == BitwiseNot) { + if (operation->kind == LogicalNot) { // single operand - op = convert_integral_to_boolean(builder, op); + rhs = convert_integral_to_boolean(builder, rhs); } else { // two operands lhs = convert_integral_to_boolean(builder, lhs); @@ -78,7 +77,6 @@ BackendError impl_logical_operation(LLVMBackendCompileUnit *unit, switch (operation->impl.bitwise) { case LogicalAnd: - // TODO: convert to either 0 or 1 *llvm_result = LLVMBuildAnd(builder, lhs, rhs, "logical and"); break; case LogicalOr: @@ -95,19 +93,153 @@ BackendError impl_logical_operation(LLVMBackendCompileUnit *unit, return SUCCESS; } +static LLVMBool is_floating_point(LLVMValueRef value) { + LLVMTypeRef valueType = LLVMTypeOf(value); + LLVMTypeKind typeKind = LLVMGetTypeKind(valueType); + + return typeKind == LLVMFloatTypeKind || typeKind == LLVMHalfTypeKind || typeKind == LLVMDoubleTypeKind || + typeKind == LLVMFP128TypeKind; +} + +static LLVMBool is_integral(LLVMValueRef value) { + LLVMTypeRef valueType = LLVMTypeOf(value); + LLVMTypeKind typeKind = LLVMGetTypeKind(valueType); + + return typeKind == LLVMIntegerTypeKind; +} + +BackendError impl_relational_operation(LLVMBackendCompileUnit *unit, + LLVMLocalScope *scope, + LLVMBuilderRef builder, + Operation *operation, + LLVMValueRef *llvm_result) { + // TODO: resolve lhs and rhs or op + LLVMValueRef rhs = NULL; + LLVMValueRef lhs = NULL; + + if ((is_integral(lhs) && is_integral(rhs)) == 1) { + // integral type + LLVMIntPredicate operator = 0; + + switch (operation->impl.relational) { + case Equal: + operator = LLVMIntEQ; + break; + case Greater: + operator = LLVMIntSGT; + break; + case Less: + operator = LLVMIntSLT; + break; + } + + *llvm_result = LLVMBuildICmp(builder, operator, lhs, rhs, "integral comparison"); + } else if ((is_floating_point(lhs) && is_floating_point(rhs)) == 1) { + // integral type + LLVMRealPredicate operator = 0; + + switch (operation->impl.relational) { + case Equal: + operator = LLVMRealOEQ; + break; + case Greater: + operator = LLVMRealOGT; + break; + case Less: + operator = LLVMRealOLT; + break; + } + + *llvm_result = LLVMBuildFCmp(builder, operator, lhs, rhs, "floating point comparison"); + } else { + PANIC("invalid type for relational operator"); + } + + return SUCCESS; +} + +BackendError impl_arithmetic_operation(LLVMBackendCompileUnit *unit, + LLVMLocalScope *scope, + LLVMBuilderRef builder, + Operation *operation, + LLVMValueRef *llvm_result) { + // TODO: resolve lhs and rhs or op + LLVMValueRef rhs = NULL; + LLVMValueRef lhs = NULL; + + if ((is_integral(lhs) && is_integral(rhs)) == 1) { + // integral type + LLVMIntPredicate operator = 0; + + switch (operation->impl.arithmetic) { + case Add: + *llvm_result = LLVMBuildNSWAdd(builder, lhs, rhs, "signed integer addition"); + break; + case Sub: + *llvm_result = LLVMBuildNSWSub(builder, lhs, rhs, "signed integer subtraction"); + break; + case Mul: + *llvm_result = LLVMBuildNSWMul(builder, lhs, rhs, "signed integer multiply"); + break; + case Div: + *llvm_result = LLVMBuildSDiv(builder, lhs, rhs, "signed integer divide"); + break; + } + + } else if ((is_floating_point(lhs) && is_floating_point(rhs)) == 1) { + // integral type + LLVMRealPredicate operator = 0; + + switch (operation->impl.arithmetic) { + case Add: + *llvm_result = LLVMBuildFAdd(builder, lhs, rhs, "floating point addition"); + break; + case Sub: + *llvm_result = LLVMBuildFSub(builder, lhs, rhs, "floating point subtraction"); + break; + case Mul: + *llvm_result = LLVMBuildFMul(builder, lhs, rhs, "floating point multiply"); + break; + case Div: + *llvm_result = LLVMBuildFDiv(builder, lhs, rhs, "floating point divide"); + break; + } + + *llvm_result = LLVMBuildFCmp(builder, operator, lhs, rhs, "floating point comparison"); + } else { + PANIC("invalid type for arithmetic operator"); + } + + return SUCCESS; +} + BackendError impl_operation(LLVMBackendCompileUnit *unit, LLVMLocalScope *scope, LLVMBuilderRef builder, Operation *operation, LLVMValueRef *llvm_result) { + BackendError err; + switch (operation->kind) { case Bitwise: - impl_bitwise_operation(unit, scope, builder, operation, - llvm_result); + err = impl_bitwise_operation(unit, scope, builder, operation, + llvm_result); break; - case Logical: - impl_logical_operation(unit, scope, builder, operation, - llvm_result); + case Boolean: + err = impl_logical_operation(unit, scope, builder, operation, + llvm_result); break; + case Relational: + err = impl_relational_operation(unit, scope, builder, operation, + llvm_result); + break; + case Arithmetic: + err = impl_arithmetic_operation(unit, scope, builder, operation, + llvm_result); + break; + default: + PANIC("Invalid operator"); } + + return err; } BackendError impl_transmute(LLVMBackendCompileUnit *unit, LLVMLocalScope *scope, @@ -186,4 +318,4 @@ BackendError impl_expr(LLVMBackendCompileUnit *unit, LLVMLocalScope *scope, } return err; -} \ No newline at end of file +} diff --git a/src/llvm/expr.h b/src/llvm/expr.h index 204b807..c0c8d47 100644 --- a/src/llvm/expr.h +++ b/src/llvm/expr.h @@ -10,7 +10,8 @@ #include #include -BackendError impl_expr(LLVMBackendCompileUnit* unit, LLVMLocalScope* scope, - Expression* expr, LLVMValueRef* llvm_result); +BackendError impl_expr(LLVMBackendCompileUnit *unit, LLVMLocalScope *scope, + LLVMBuilderRef builder, Expression *expr, + LLVMValueRef *llvm_result); #endif // LLVM_BACKEND_EXPR_H