From 8603656e3efbeddd22c1e83531e1c1feac625112 Mon Sep 17 00:00:00 2001 From: servostar Date: Wed, 22 May 2024 20:27:40 +0200 Subject: [PATCH] added basic expression support --- src/ast/ast.c | 12 +++ src/ast/ast.h | 2 + src/llvm/backend.c | 10 +-- src/llvm/expr/build.c | 134 +++++++++++++++++++++++++++++ src/llvm/expr/build.h | 13 +++ src/llvm/function/function-types.h | 3 + src/llvm/function/function.c | 39 ++++++++- src/llvm/function/function.h | 7 ++ src/llvm/stmt/build.c | 89 +++++++++++++++++++ src/llvm/stmt/build.h | 11 +++ src/llvm/types/scope.c | 12 ++- src/llvm/types/scope.h | 15 +++- 12 files changed, 336 insertions(+), 11 deletions(-) create mode 100644 src/llvm/expr/build.c create mode 100644 src/llvm/expr/build.h create mode 100644 src/llvm/stmt/build.c create mode 100644 src/llvm/stmt/build.h diff --git a/src/ast/ast.c b/src/ast/ast.c index e503e1e..3002570 100644 --- a/src/ast/ast.c +++ b/src/ast/ast.c @@ -287,3 +287,15 @@ void AST_fprint_graphviz(FILE* stream, const struct AST_Node_t* root) { fprintf(stream, "}\n"); } + +AST_NODE_PTR AST_get_node_by_kind(AST_NODE_PTR owner, enum AST_SyntaxElement_t kind) { + for (size_t i = 0; i < owner->child_count; i++) { + AST_NODE_PTR child = AST_get_node(owner, i); + + if (child->kind == kind) { + return child; + } + } + + return NULL; +} \ No newline at end of file diff --git a/src/ast/ast.h b/src/ast/ast.h index ff8297b..5772675 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -213,4 +213,6 @@ void AST_visit_nodes_recurse(struct AST_Node_t *root, [[gnu::nonnull(1), gnu::nonnull(2)]] void AST_fprint_graphviz(FILE* stream, const struct AST_Node_t* node); +AST_NODE_PTR AST_get_node_by_kind(AST_NODE_PTR owner, enum AST_SyntaxElement_t kind); + #endif diff --git a/src/llvm/backend.c b/src/llvm/backend.c index 235f59c..a2fcef9 100644 --- a/src/llvm/backend.c +++ b/src/llvm/backend.c @@ -32,7 +32,6 @@ static BackendError llvm_backend_codegen(const AST_NODE_PTR module_node, void**) AST_NODE_PTR global_node = AST_get_node(module_node, i); GemstoneTypedefRef typedefref; - GemstoneFunRef funref; GArray* decls; switch (global_node->kind) { @@ -41,20 +40,19 @@ static BackendError llvm_backend_codegen(const AST_NODE_PTR module_node, void**) type_scope_append_type(global_scope, typedefref); break; case AST_Fun: - funref = fun_from_ast(global_scope, global_node); - type_scope_add_fun(global_scope, funref); + llvm_generate_function_implementation(global_scope, module, global_node); break; case AST_Decl: decls = declaration_from_ast(global_scope, global_node); for (size_t i = 0; i < decls->len; i++) { GemstoneDeclRef decl = ((GemstoneDeclRef*) decls->data)[i]; - type_scope_add_variable(global_scope, decl); - LLVMValueRef llvm_decl = NULL; - err = llvm_create_declaration(module, NULL, decl, &llvm_decl); + err = llvm_create_declaration(module, NULL, decl, &decl->llvm_value); if (err.kind != Success) break; + + type_scope_add_variable(global_scope, decl); } break; diff --git a/src/llvm/expr/build.c b/src/llvm/expr/build.c new file mode 100644 index 0000000..0bcba93 --- /dev/null +++ b/src/llvm/expr/build.c @@ -0,0 +1,134 @@ +#include +#include +#include +#include +#include +#include + +BackendError llvm_build_arithmetic_operation(LLVMBuilderRef builder, TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR expr_node, enum AST_SyntaxElement_t operation, LLVMValueRef* yield) { + AST_NODE_PTR expr_lhs = AST_get_node(expr_node, 0); + AST_NODE_PTR expr_rhs = AST_get_node(expr_node, 1); + + LLVMValueRef llvm_lhs = NULL; + LLVMValueRef llvm_rhs = NULL; + BackendError err; + + err = llvm_build_expression(builder, scope, module, expr_lhs, &llvm_lhs); + if (err.kind != Success) + return err; + + err = llvm_build_expression(builder, scope, module, expr_rhs, &llvm_rhs); + if (err.kind != Success) + return err; + + switch (operation) { + case AST_Add: + *yield = LLVMBuildAdd(builder, llvm_lhs, llvm_rhs, "Addition"); + break; + case AST_Sub: + *yield = LLVMBuildSub(builder, llvm_lhs, llvm_rhs, "Subtraction"); + break; + case AST_Mul: + *yield = LLVMBuildMul(builder, llvm_lhs, llvm_rhs, "Multiplication"); + break; + case AST_Div: + *yield = LLVMBuildSDiv(builder, llvm_lhs, llvm_rhs, "Division"); + break; + default: + break; + } + + return new_backend_impl_error(Implementation, expr_node, "invalid arithmetic operation"); +} + +BackendError llvm_build_relational_operation(LLVMBuilderRef builder, TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR expr_node, enum AST_SyntaxElement_t operation, LLVMValueRef* yield) { + AST_NODE_PTR expr_lhs = AST_get_node(expr_node, 0); + AST_NODE_PTR expr_rhs = AST_get_node(expr_node, 1); + + LLVMValueRef llvm_lhs = NULL; + LLVMValueRef llvm_rhs = NULL; + BackendError err; + + err = llvm_build_expression(builder, scope, module, expr_lhs, &llvm_lhs); + if (err.kind != Success) + return err; + + err = llvm_build_expression(builder, scope, module, expr_rhs, &llvm_rhs); + if (err.kind != Success) + return err; + + // TODO: make a difference between SignedInt, UnsignedInt and Float + switch (operation) { + case AST_Eq: + *yield = LLVMBuildICmp(builder, LLVMIntEQ, llvm_lhs, llvm_rhs, "Equal"); + break; + case AST_Greater: + *yield = LLVMBuildICmp(builder, LLVMIntSGT, llvm_lhs, llvm_rhs, "Greater"); + break; + case AST_Less: + *yield = LLVMBuildICmp(builder, LLVMIntSLT, llvm_lhs, llvm_rhs, "Less"); + break; + default: + break; + } + + return new_backend_impl_error(Implementation, expr_node, "invalid arithmetic operation"); +} + +BackendError llvm_build_expression(LLVMBuilderRef builder, TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR expr_node, LLVMValueRef* yield) { + + switch (expr_node->kind) { + case AST_Ident: { + AST_NODE_PTR variable_name = AST_get_node(expr_node, 0); + GemstoneDeclRef decl = type_scope_get_variable(scope, variable_name->value); + *yield = decl->llvm_value; + } + break; + case AST_Int: { + AST_NODE_PTR constant = AST_get_node(expr_node, 0); + // TODO: type annotation needed + *yield = LLVMConstIntOfString(LLVMInt32Type(), constant->value, 10); + } + case AST_Float: { + AST_NODE_PTR constant = AST_get_node(expr_node, 0); + // TODO: type annotation needed + *yield = LLVMConstRealOfString(LLVMFloatType(), constant->value); + } + break; + case AST_Add: + case AST_Sub: + case AST_Mul: + case AST_Div: { + BackendError err = llvm_build_arithmetic_operation(builder, scope, module, expr_node, expr_node->kind, yield); + if (err.kind != Success) + return err; + } + case AST_Eq: + case AST_Greater: + case AST_Less: { + BackendError err = llvm_build_relational_operation(builder, scope, module, expr_node, expr_node->kind, yield); + if (err.kind != Success) + return err; + } + break; + } + + return SUCCESS; +} + +BackendError llvm_build_expression_list(LLVMBuilderRef builder, TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR exprlist_node, LLVMValueRef** yields) { + + if (exprlist_node->kind != AST_ExprList) { + return new_backend_impl_error(Implementation, exprlist_node, "expected expression list"); + } + + *yields = malloc(sizeof(LLVMValueRef) * exprlist_node->child_count); + + for (size_t i = 0; i < exprlist_node->child_count; i++) { + AST_NODE_PTR expr = AST_get_node(exprlist_node, 0); + + llvm_build_expression(builder, scope, module, expr, *yields + i); + } + + return SUCCESS; +} diff --git a/src/llvm/expr/build.h b/src/llvm/expr/build.h new file mode 100644 index 0000000..a4d2c28 --- /dev/null +++ b/src/llvm/expr/build.h @@ -0,0 +1,13 @@ + +#ifndef LLVM_EXPR_BUILD_H_ +#define LLVM_EXPR_BUILD_H_ + +#include "codegen/backend.h" +#include "llvm/types/scope.h" +#include + +BackendError llvm_build_expression(LLVMBuilderRef builder, TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR expr_node, LLVMValueRef* yield); + +BackendError llvm_build_expression_list(LLVMBuilderRef builder, TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR expr_node, LLVMValueRef** yields); + +#endif // LLVM_EXPR_BUILD_H_ diff --git a/src/llvm/function/function-types.h b/src/llvm/function/function-types.h index 0af0c1a..211bf54 100644 --- a/src/llvm/function/function-types.h +++ b/src/llvm/function/function-types.h @@ -2,6 +2,7 @@ #ifndef LLVM_TYPES_FUNCTION_TYPES_H_ #define LLVM_TYPES_FUNCTION_TYPES_H_ +#include #include #include @@ -21,6 +22,8 @@ typedef struct GemstoneParam_t { typedef struct GemstoneFun_t { const char* name; GArray* params; + LLVMTypeRef llvm_signature; + LLVMValueRef llvm_function; } GemstoneFun; typedef GemstoneFun* GemstoneFunRef; diff --git a/src/llvm/function/function.c b/src/llvm/function/function.c index 4e8c643..3623f79 100644 --- a/src/llvm/function/function.c +++ b/src/llvm/function/function.c @@ -1,7 +1,10 @@ +#include "codegen/backend.h" #include "llvm/function/function-types.h" +#include "llvm/stmt/build.h" #include #include +#include #include #include #include @@ -108,7 +111,7 @@ void fun_delete(const GemstoneFunRef fun) { free(fun); } -LLVMTypeRef get_gemstone_function_llvm_signature(LLVMContextRef context, GemstoneFunRef function) { +LLVMTypeRef llvm_generate_function_signature(LLVMContextRef context, GemstoneFunRef function) { unsigned int param_count = function->params->len; LLVMTypeRef* params = malloc(sizeof(LLVMTypeRef)); @@ -120,3 +123,37 @@ LLVMTypeRef get_gemstone_function_llvm_signature(LLVMContextRef context, Gemston return LLVMFunctionType(LLVMVoidType(), params, param_count, 0); } + +BackendError llvm_generate_function_implementation(TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR node) { + LLVMContextRef context = LLVMGetModuleContext(module); + GemstoneFunRef gemstone_signature = fun_from_ast(scope, node); + + gemstone_signature->llvm_signature = llvm_generate_function_signature(context, gemstone_signature); + gemstone_signature->llvm_function = LLVMAddFunction(module, gemstone_signature->name, gemstone_signature->llvm_signature); + + type_scope_add_fun(scope, gemstone_signature); + + LLVMBasicBlockRef llvm_body = LLVMAppendBasicBlock(gemstone_signature->llvm_function, "body"); + LLVMBuilderRef llvm_builder = LLVMCreateBuilderInContext(context); + LLVMPositionBuilderAtEnd(llvm_builder, llvm_body); + + // create new function local scope + TypeScopeRef local_scope = type_scope_new(); + size_t local_scope_idx = type_scope_append_scope(scope, local_scope); + + for (size_t i = 0; i < node->child_count; i++) { + AST_NODE_PTR child_node = AST_get_node(node, i); + if (child_node->kind == AST_StmtList) { + llvm_build_statement_list(llvm_builder, local_scope, module, child_node); + } + } + + // automatic return at end of function + LLVMBuildRetVoid(llvm_builder); + + // dispose function local scope + type_scope_remove_scope(scope, local_scope_idx); + type_scope_delete(local_scope); + + return SUCCESS; +} diff --git a/src/llvm/function/function.h b/src/llvm/function/function.h index 3e6ceec..833414f 100644 --- a/src/llvm/function/function.h +++ b/src/llvm/function/function.h @@ -2,9 +2,12 @@ #ifndef LLVM_FUNCTION_H_ #define LLVM_FUNCTION_H_ +#include #include +#include #include #include +#include /** * @brief Convert an AST node into a function parameter struct @@ -29,4 +32,8 @@ GemstoneFunRef fun_from_ast(const TypeScopeRef scope, const AST_NODE_PTR node); */ void fun_delete(const GemstoneFunRef fun); +LLVMTypeRef llvm_generate_function_signature(LLVMContextRef context, GemstoneFunRef function); + +BackendError llvm_generate_function_implementation(TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR body); + #endif // LLVM_FUNCTION_H_ diff --git a/src/llvm/stmt/build.c b/src/llvm/stmt/build.c new file mode 100644 index 0000000..d9234fe --- /dev/null +++ b/src/llvm/stmt/build.c @@ -0,0 +1,89 @@ + +#include "llvm/function/function-types.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +BackendError llvm_build_statement(LLVMBuilderRef builder, TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR stmt_node) { + switch (stmt_node->kind) { + case AST_Decl: { + GArray* decls = declaration_from_ast(scope, stmt_node); + + for (size_t i = 0; i < decls->len; i++) { + GemstoneDeclRef decl = ((GemstoneDeclRef*) decls->data)[i]; + + BackendError err = llvm_create_declaration(module, builder, decl, &decl->llvm_value); + + if (err.kind != Success) + break; + + type_scope_add_variable(scope, decl); + } + + // TODO: make sure all decls are freed later + g_array_free(decls, FALSE); + } + break; + case AST_Assign: { + AST_NODE_PTR variable_name = AST_get_node(stmt_node, 0); + AST_NODE_PTR expression = AST_get_node(stmt_node, 1); + + LLVMValueRef yield = NULL; + BackendError err = llvm_build_expression(builder, scope, module, expression, &yield); + + GemstoneDeclRef variable = type_scope_get_variable(scope, variable_name->value); + + LLVMBuildStore(builder, yield, variable->llvm_value); + } + break; + case AST_Stmt: + llvm_build_statement(builder, scope, module, stmt_node); + break; + case AST_Call: { + AST_NODE_PTR name = AST_get_node(stmt_node, 0); + AST_NODE_PTR expr_list = AST_get_node_by_kind(stmt_node, AST_ExprList); + GemstoneFunRef function_signature = type_scope_get_fun_from_name(scope, name->value); + size_t arg_count = function_signature->params->len; + + LLVMValueRef* args = NULL; + BackendError err = llvm_build_expression_list(builder, scope, module, expr_list, &args); + + LLVMBuildCall2(builder, function_signature->llvm_signature, function_signature->llvm_function, args, arg_count, name->value); + } + break; + case AST_Def: + // TODO: implement definition + break; + case AST_While: + // TODO: implement while + break; + case AST_If: + // TODO: implement if + break; + case AST_IfElse: + // TODO: implement else if + break; + case AST_Else: + // TODO: implement else + break; + default: + ERROR("Invalid AST node: %s", AST_node_to_string(stmt_node)); + return new_backend_impl_error(Implementation, stmt_node, "AST is invalid"); + } + + return SUCCESS; +} + +BackendError llvm_build_statement_list(LLVMBuilderRef builder, TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR node) { + for (size_t i = 0; i < node->child_count; i++) { + AST_NODE_PTR stmt_node = AST_get_node(node, i); + + llvm_build_statement(builder, scope, module, stmt_node); + } +} diff --git a/src/llvm/stmt/build.h b/src/llvm/stmt/build.h new file mode 100644 index 0000000..726813a --- /dev/null +++ b/src/llvm/stmt/build.h @@ -0,0 +1,11 @@ + +#ifndef LLVM_STMT_BUILD_H_ +#define LLVM_STMT_BUILD_H_ + +#include +#include +#include + +BackendError llvm_build_statement_list(LLVMBuilderRef builder, TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR node); + +#endif // LLVM_STMT_BUILD_H_ diff --git a/src/llvm/types/scope.c b/src/llvm/types/scope.c index 445c5ee..15f7634 100644 --- a/src/llvm/types/scope.c +++ b/src/llvm/types/scope.c @@ -30,9 +30,15 @@ void type_scope_append_type(TypeScopeRef scope, GemstoneTypedefRef type) { g_array_append_val(scope->types, type); } -void type_scope_append_scope(TypeScopeRef scope, TypeScopeRef child_scope) { - g_array_append_val(scope->scopes, child_scope); - child_scope->parent = scope; +size_t type_scope_append_scope(TypeScopeRef scope, TypeScopeRef child) { + child->parent = scope; + g_array_append_val(scope->scopes, child); + + return scope->scopes->len - 1; +} + +void type_scope_remove_scope(TypeScopeRef scope, size_t index) { + g_array_remove_index(scope->scopes, index); } GemstoneTypedefRef type_scope_get_type(TypeScopeRef scope, size_t index) { diff --git a/src/llvm/types/scope.h b/src/llvm/types/scope.h index a8beac2..84e37fd 100644 --- a/src/llvm/types/scope.h +++ b/src/llvm/types/scope.h @@ -2,6 +2,7 @@ #ifndef LLVM_TYPE_SCOPE_H_ #define LLVM_TYPE_SCOPE_H_ +#include #include #include #include @@ -17,6 +18,7 @@ typedef struct GemstoneDecl_t { const char* name; StorageQualifier storageQualifier; GemstoneTypeRef type; + LLVMValueRef llvm_value; } GemstoneDecl; typedef GemstoneDecl* GemstoneDeclRef; @@ -49,7 +51,16 @@ void type_scope_append_type(TypeScopeRef scope, GemstoneTypedefRef type); * @param child_scope */ [[gnu::nonnull(1), gnu::nonnull(2)]] -void type_scope_append_scope(TypeScopeRef scope, TypeScopeRef child_scope); +size_t type_scope_append_scope(TypeScopeRef scope, TypeScopeRef child_scope); + +/** + * @brief Remove a new child scope to this scope + * + * @param scope + * @param child_scope + */ +[[gnu::nonnull(1)]] +void type_scope_remove_scope(TypeScopeRef scope, size_t index); /** * @brief Get the type at the specified index in this scope level @@ -116,4 +127,6 @@ GemstoneFunRef type_scope_get_fun_from_name(TypeScopeRef scope, const char* name void type_scope_add_variable(TypeScopeRef scope, GemstoneDeclRef decl); +GemstoneDeclRef type_scope_get_variable(TypeScopeRef scope, const char *name); + #endif // LLVM_TYPE_SCOPE_H_