added basic expression support

This commit is contained in:
Sven Vogel 2024-05-22 20:27:40 +02:00
parent 00089a4939
commit 8603656e3e
12 changed files with 336 additions and 11 deletions

View File

@ -287,3 +287,15 @@ void AST_fprint_graphviz(FILE* stream, const struct AST_Node_t* root) {
fprintf(stream, "}\n");
}
AST_NODE_PTR AST_get_node_by_kind(AST_NODE_PTR owner, enum AST_SyntaxElement_t kind) {
for (size_t i = 0; i < owner->child_count; i++) {
AST_NODE_PTR child = AST_get_node(owner, i);
if (child->kind == kind) {
return child;
}
}
return NULL;
}

View File

@ -213,4 +213,6 @@ void AST_visit_nodes_recurse(struct AST_Node_t *root,
[[gnu::nonnull(1), gnu::nonnull(2)]]
void AST_fprint_graphviz(FILE* stream, const struct AST_Node_t* node);
AST_NODE_PTR AST_get_node_by_kind(AST_NODE_PTR owner, enum AST_SyntaxElement_t kind);
#endif

View File

@ -32,7 +32,6 @@ static BackendError llvm_backend_codegen(const AST_NODE_PTR module_node, void**)
AST_NODE_PTR global_node = AST_get_node(module_node, i);
GemstoneTypedefRef typedefref;
GemstoneFunRef funref;
GArray* decls;
switch (global_node->kind) {
@ -41,20 +40,19 @@ static BackendError llvm_backend_codegen(const AST_NODE_PTR module_node, void**)
type_scope_append_type(global_scope, typedefref);
break;
case AST_Fun:
funref = fun_from_ast(global_scope, global_node);
type_scope_add_fun(global_scope, funref);
llvm_generate_function_implementation(global_scope, module, global_node);
break;
case AST_Decl:
decls = declaration_from_ast(global_scope, global_node);
for (size_t i = 0; i < decls->len; i++) {
GemstoneDeclRef decl = ((GemstoneDeclRef*) decls->data)[i];
type_scope_add_variable(global_scope, decl);
LLVMValueRef llvm_decl = NULL;
err = llvm_create_declaration(module, NULL, decl, &llvm_decl);
err = llvm_create_declaration(module, NULL, decl, &decl->llvm_value);
if (err.kind != Success)
break;
type_scope_add_variable(global_scope, decl);
}
break;

134
src/llvm/expr/build.c Normal file
View File

@ -0,0 +1,134 @@
#include <llvm/types/scope.h>
#include <ast/ast.h>
#include <codegen/backend.h>
#include <llvm-c/Core.h>
#include <llvm-c/Types.h>
#include <llvm/expr/build.h>
BackendError llvm_build_arithmetic_operation(LLVMBuilderRef builder, TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR expr_node, enum AST_SyntaxElement_t operation, LLVMValueRef* yield) {
AST_NODE_PTR expr_lhs = AST_get_node(expr_node, 0);
AST_NODE_PTR expr_rhs = AST_get_node(expr_node, 1);
LLVMValueRef llvm_lhs = NULL;
LLVMValueRef llvm_rhs = NULL;
BackendError err;
err = llvm_build_expression(builder, scope, module, expr_lhs, &llvm_lhs);
if (err.kind != Success)
return err;
err = llvm_build_expression(builder, scope, module, expr_rhs, &llvm_rhs);
if (err.kind != Success)
return err;
switch (operation) {
case AST_Add:
*yield = LLVMBuildAdd(builder, llvm_lhs, llvm_rhs, "Addition");
break;
case AST_Sub:
*yield = LLVMBuildSub(builder, llvm_lhs, llvm_rhs, "Subtraction");
break;
case AST_Mul:
*yield = LLVMBuildMul(builder, llvm_lhs, llvm_rhs, "Multiplication");
break;
case AST_Div:
*yield = LLVMBuildSDiv(builder, llvm_lhs, llvm_rhs, "Division");
break;
default:
break;
}
return new_backend_impl_error(Implementation, expr_node, "invalid arithmetic operation");
}
BackendError llvm_build_relational_operation(LLVMBuilderRef builder, TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR expr_node, enum AST_SyntaxElement_t operation, LLVMValueRef* yield) {
AST_NODE_PTR expr_lhs = AST_get_node(expr_node, 0);
AST_NODE_PTR expr_rhs = AST_get_node(expr_node, 1);
LLVMValueRef llvm_lhs = NULL;
LLVMValueRef llvm_rhs = NULL;
BackendError err;
err = llvm_build_expression(builder, scope, module, expr_lhs, &llvm_lhs);
if (err.kind != Success)
return err;
err = llvm_build_expression(builder, scope, module, expr_rhs, &llvm_rhs);
if (err.kind != Success)
return err;
// TODO: make a difference between SignedInt, UnsignedInt and Float
switch (operation) {
case AST_Eq:
*yield = LLVMBuildICmp(builder, LLVMIntEQ, llvm_lhs, llvm_rhs, "Equal");
break;
case AST_Greater:
*yield = LLVMBuildICmp(builder, LLVMIntSGT, llvm_lhs, llvm_rhs, "Greater");
break;
case AST_Less:
*yield = LLVMBuildICmp(builder, LLVMIntSLT, llvm_lhs, llvm_rhs, "Less");
break;
default:
break;
}
return new_backend_impl_error(Implementation, expr_node, "invalid arithmetic operation");
}
BackendError llvm_build_expression(LLVMBuilderRef builder, TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR expr_node, LLVMValueRef* yield) {
switch (expr_node->kind) {
case AST_Ident: {
AST_NODE_PTR variable_name = AST_get_node(expr_node, 0);
GemstoneDeclRef decl = type_scope_get_variable(scope, variable_name->value);
*yield = decl->llvm_value;
}
break;
case AST_Int: {
AST_NODE_PTR constant = AST_get_node(expr_node, 0);
// TODO: type annotation needed
*yield = LLVMConstIntOfString(LLVMInt32Type(), constant->value, 10);
}
case AST_Float: {
AST_NODE_PTR constant = AST_get_node(expr_node, 0);
// TODO: type annotation needed
*yield = LLVMConstRealOfString(LLVMFloatType(), constant->value);
}
break;
case AST_Add:
case AST_Sub:
case AST_Mul:
case AST_Div: {
BackendError err = llvm_build_arithmetic_operation(builder, scope, module, expr_node, expr_node->kind, yield);
if (err.kind != Success)
return err;
}
case AST_Eq:
case AST_Greater:
case AST_Less: {
BackendError err = llvm_build_relational_operation(builder, scope, module, expr_node, expr_node->kind, yield);
if (err.kind != Success)
return err;
}
break;
}
return SUCCESS;
}
BackendError llvm_build_expression_list(LLVMBuilderRef builder, TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR exprlist_node, LLVMValueRef** yields) {
if (exprlist_node->kind != AST_ExprList) {
return new_backend_impl_error(Implementation, exprlist_node, "expected expression list");
}
*yields = malloc(sizeof(LLVMValueRef) * exprlist_node->child_count);
for (size_t i = 0; i < exprlist_node->child_count; i++) {
AST_NODE_PTR expr = AST_get_node(exprlist_node, 0);
llvm_build_expression(builder, scope, module, expr, *yields + i);
}
return SUCCESS;
}

13
src/llvm/expr/build.h Normal file
View File

@ -0,0 +1,13 @@
#ifndef LLVM_EXPR_BUILD_H_
#define LLVM_EXPR_BUILD_H_
#include "codegen/backend.h"
#include "llvm/types/scope.h"
#include <llvm-c/Types.h>
BackendError llvm_build_expression(LLVMBuilderRef builder, TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR expr_node, LLVMValueRef* yield);
BackendError llvm_build_expression_list(LLVMBuilderRef builder, TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR expr_node, LLVMValueRef** yields);
#endif // LLVM_EXPR_BUILD_H_

View File

@ -2,6 +2,7 @@
#ifndef LLVM_TYPES_FUNCTION_TYPES_H_
#define LLVM_TYPES_FUNCTION_TYPES_H_
#include <llvm-c/Types.h>
#include <llvm/types/structs.h>
#include <glib.h>
@ -21,6 +22,8 @@ typedef struct GemstoneParam_t {
typedef struct GemstoneFun_t {
const char* name;
GArray* params;
LLVMTypeRef llvm_signature;
LLVMValueRef llvm_function;
} GemstoneFun;
typedef GemstoneFun* GemstoneFunRef;

View File

@ -1,7 +1,10 @@
#include "codegen/backend.h"
#include "llvm/function/function-types.h"
#include "llvm/stmt/build.h"
#include <ast/ast.h>
#include <llvm-c/Core.h>
#include <llvm-c/Types.h>
#include <llvm/types/scope.h>
#include <llvm/function/function.h>
#include <llvm/types/type.h>
@ -108,7 +111,7 @@ void fun_delete(const GemstoneFunRef fun) {
free(fun);
}
LLVMTypeRef get_gemstone_function_llvm_signature(LLVMContextRef context, GemstoneFunRef function) {
LLVMTypeRef llvm_generate_function_signature(LLVMContextRef context, GemstoneFunRef function) {
unsigned int param_count = function->params->len;
LLVMTypeRef* params = malloc(sizeof(LLVMTypeRef));
@ -120,3 +123,37 @@ LLVMTypeRef get_gemstone_function_llvm_signature(LLVMContextRef context, Gemston
return LLVMFunctionType(LLVMVoidType(), params, param_count, 0);
}
BackendError llvm_generate_function_implementation(TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR node) {
LLVMContextRef context = LLVMGetModuleContext(module);
GemstoneFunRef gemstone_signature = fun_from_ast(scope, node);
gemstone_signature->llvm_signature = llvm_generate_function_signature(context, gemstone_signature);
gemstone_signature->llvm_function = LLVMAddFunction(module, gemstone_signature->name, gemstone_signature->llvm_signature);
type_scope_add_fun(scope, gemstone_signature);
LLVMBasicBlockRef llvm_body = LLVMAppendBasicBlock(gemstone_signature->llvm_function, "body");
LLVMBuilderRef llvm_builder = LLVMCreateBuilderInContext(context);
LLVMPositionBuilderAtEnd(llvm_builder, llvm_body);
// create new function local scope
TypeScopeRef local_scope = type_scope_new();
size_t local_scope_idx = type_scope_append_scope(scope, local_scope);
for (size_t i = 0; i < node->child_count; i++) {
AST_NODE_PTR child_node = AST_get_node(node, i);
if (child_node->kind == AST_StmtList) {
llvm_build_statement_list(llvm_builder, local_scope, module, child_node);
}
}
// automatic return at end of function
LLVMBuildRetVoid(llvm_builder);
// dispose function local scope
type_scope_remove_scope(scope, local_scope_idx);
type_scope_delete(local_scope);
return SUCCESS;
}

View File

@ -2,9 +2,12 @@
#ifndef LLVM_FUNCTION_H_
#define LLVM_FUNCTION_H_
#include <codegen/backend.h>
#include <ast/ast.h>
#include <llvm-c/Types.h>
#include <llvm/function/function-types.h>
#include <llvm/types/scope.h>
#include <llvm-c/Core.h>
/**
* @brief Convert an AST node into a function parameter struct
@ -29,4 +32,8 @@ GemstoneFunRef fun_from_ast(const TypeScopeRef scope, const AST_NODE_PTR node);
*/
void fun_delete(const GemstoneFunRef fun);
LLVMTypeRef llvm_generate_function_signature(LLVMContextRef context, GemstoneFunRef function);
BackendError llvm_generate_function_implementation(TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR body);
#endif // LLVM_FUNCTION_H_

89
src/llvm/stmt/build.c Normal file
View File

@ -0,0 +1,89 @@
#include "llvm/function/function-types.h"
#include <codegen/backend.h>
#include <llvm-c/Core.h>
#include <llvm-c/Types.h>
#include <llvm/decl/variable.h>
#include <llvm/types/scope.h>
#include <ast/ast.h>
#include <llvm/stmt/build.h>
#include <llvm/expr/build.h>
#include <sys/log.h>
BackendError llvm_build_statement(LLVMBuilderRef builder, TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR stmt_node) {
switch (stmt_node->kind) {
case AST_Decl: {
GArray* decls = declaration_from_ast(scope, stmt_node);
for (size_t i = 0; i < decls->len; i++) {
GemstoneDeclRef decl = ((GemstoneDeclRef*) decls->data)[i];
BackendError err = llvm_create_declaration(module, builder, decl, &decl->llvm_value);
if (err.kind != Success)
break;
type_scope_add_variable(scope, decl);
}
// TODO: make sure all decls are freed later
g_array_free(decls, FALSE);
}
break;
case AST_Assign: {
AST_NODE_PTR variable_name = AST_get_node(stmt_node, 0);
AST_NODE_PTR expression = AST_get_node(stmt_node, 1);
LLVMValueRef yield = NULL;
BackendError err = llvm_build_expression(builder, scope, module, expression, &yield);
GemstoneDeclRef variable = type_scope_get_variable(scope, variable_name->value);
LLVMBuildStore(builder, yield, variable->llvm_value);
}
break;
case AST_Stmt:
llvm_build_statement(builder, scope, module, stmt_node);
break;
case AST_Call: {
AST_NODE_PTR name = AST_get_node(stmt_node, 0);
AST_NODE_PTR expr_list = AST_get_node_by_kind(stmt_node, AST_ExprList);
GemstoneFunRef function_signature = type_scope_get_fun_from_name(scope, name->value);
size_t arg_count = function_signature->params->len;
LLVMValueRef* args = NULL;
BackendError err = llvm_build_expression_list(builder, scope, module, expr_list, &args);
LLVMBuildCall2(builder, function_signature->llvm_signature, function_signature->llvm_function, args, arg_count, name->value);
}
break;
case AST_Def:
// TODO: implement definition
break;
case AST_While:
// TODO: implement while
break;
case AST_If:
// TODO: implement if
break;
case AST_IfElse:
// TODO: implement else if
break;
case AST_Else:
// TODO: implement else
break;
default:
ERROR("Invalid AST node: %s", AST_node_to_string(stmt_node));
return new_backend_impl_error(Implementation, stmt_node, "AST is invalid");
}
return SUCCESS;
}
BackendError llvm_build_statement_list(LLVMBuilderRef builder, TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR node) {
for (size_t i = 0; i < node->child_count; i++) {
AST_NODE_PTR stmt_node = AST_get_node(node, i);
llvm_build_statement(builder, scope, module, stmt_node);
}
}

11
src/llvm/stmt/build.h Normal file
View File

@ -0,0 +1,11 @@
#ifndef LLVM_STMT_BUILD_H_
#define LLVM_STMT_BUILD_H_
#include <llvm/types/scope.h>
#include <codegen/backend.h>
#include <llvm-c/Types.h>
BackendError llvm_build_statement_list(LLVMBuilderRef builder, TypeScopeRef scope, LLVMModuleRef module, AST_NODE_PTR node);
#endif // LLVM_STMT_BUILD_H_

View File

@ -30,9 +30,15 @@ void type_scope_append_type(TypeScopeRef scope, GemstoneTypedefRef type) {
g_array_append_val(scope->types, type);
}
void type_scope_append_scope(TypeScopeRef scope, TypeScopeRef child_scope) {
g_array_append_val(scope->scopes, child_scope);
child_scope->parent = scope;
size_t type_scope_append_scope(TypeScopeRef scope, TypeScopeRef child) {
child->parent = scope;
g_array_append_val(scope->scopes, child);
return scope->scopes->len - 1;
}
void type_scope_remove_scope(TypeScopeRef scope, size_t index) {
g_array_remove_index(scope->scopes, index);
}
GemstoneTypedefRef type_scope_get_type(TypeScopeRef scope, size_t index) {

View File

@ -2,6 +2,7 @@
#ifndef LLVM_TYPE_SCOPE_H_
#define LLVM_TYPE_SCOPE_H_
#include <llvm-c/Types.h>
#include <llvm/function/function-types.h>
#include <glib.h>
#include <llvm/types/structs.h>
@ -17,6 +18,7 @@ typedef struct GemstoneDecl_t {
const char* name;
StorageQualifier storageQualifier;
GemstoneTypeRef type;
LLVMValueRef llvm_value;
} GemstoneDecl;
typedef GemstoneDecl* GemstoneDeclRef;
@ -49,7 +51,16 @@ void type_scope_append_type(TypeScopeRef scope, GemstoneTypedefRef type);
* @param child_scope
*/
[[gnu::nonnull(1), gnu::nonnull(2)]]
void type_scope_append_scope(TypeScopeRef scope, TypeScopeRef child_scope);
size_t type_scope_append_scope(TypeScopeRef scope, TypeScopeRef child_scope);
/**
* @brief Remove a new child scope to this scope
*
* @param scope
* @param child_scope
*/
[[gnu::nonnull(1)]]
void type_scope_remove_scope(TypeScopeRef scope, size_t index);
/**
* @brief Get the type at the specified index in this scope level
@ -116,4 +127,6 @@ GemstoneFunRef type_scope_get_fun_from_name(TypeScopeRef scope, const char* name
void type_scope_add_variable(TypeScopeRef scope, GemstoneDeclRef decl);
GemstoneDeclRef type_scope_get_variable(TypeScopeRef scope, const char *name);
#endif // LLVM_TYPE_SCOPE_H_