From b786b3e156be3881efa059381216f64515ca5ed1 Mon Sep 17 00:00:00 2001 From: servostar Date: Sun, 4 Aug 2024 16:20:23 +0200 Subject: [PATCH] added function return type --- src/ast/ast.c | 5 +- src/ast/ast.h | 5 +- src/llvm/llvm-ir/expr.c | 3 + src/llvm/llvm-ir/func.c | 133 ++++++++++++++++++++++++-- src/llvm/llvm-ir/func.h | 5 + src/llvm/llvm-ir/stmt.c | 108 +++++++-------------- src/set/set.c | 184 ++++++++++++++++++++++++++++++------ src/set/types.c | 28 ++++++ src/set/types.h | 20 +++- src/yacc/parser.y | 54 ++++++----- tests/stdlib/src/matrix.gsc | 11 +-- 11 files changed, 412 insertions(+), 144 deletions(-) create mode 100644 src/set/types.c diff --git a/src/ast/ast.c b/src/ast/ast.c index 582c3d4..d11b9c3 100644 --- a/src/ast/ast.c +++ b/src/ast/ast.c @@ -67,7 +67,10 @@ void AST_init() { lookup_table[AST_Typedef] = "typedef"; lookup_table[AST_Box] = "box"; - lookup_table[AST_Fun] = "fun"; + lookup_table[AST_FunDecl] = "fun"; + lookup_table[AST_FunDef] = "fun"; + lookup_table[AST_ProcDecl] = "fun"; + lookup_table[AST_ProcDef] = "fun"; lookup_table[AST_Call] = "funcall"; lookup_table[AST_Typecast] = "typecast"; diff --git a/src/ast/ast.h b/src/ast/ast.h index 4ead853..256eb96 100644 --- a/src/ast/ast.h +++ b/src/ast/ast.h @@ -57,7 +57,10 @@ enum AST_SyntaxElement_t { // Defintions AST_Typedef, AST_Box, - AST_Fun, + AST_FunDecl, + AST_FunDef, + AST_ProcDecl, + AST_ProcDef, AST_Import, // amount of variants // in this enums diff --git a/src/llvm/llvm-ir/expr.c b/src/llvm/llvm-ir/expr.c index d4129ea..eef1cef 100644 --- a/src/llvm/llvm-ir/expr.c +++ b/src/llvm/llvm-ir/expr.c @@ -524,6 +524,9 @@ BackendError impl_expr(LLVMBackendCompileUnit *unit, LLVMLocalScope *scope, deref_depth, llvm_result); break; + case ExpressionKindFunctionCall: + err = impl_func_call(unit, builder, scope, expr->impl.call, llvm_result); + break; default: err = new_backend_impl_error(Implementation, NULL, "unknown expression"); break; diff --git a/src/llvm/llvm-ir/func.c b/src/llvm/llvm-ir/func.c index 032aa3b..09eb2b5 100644 --- a/src/llvm/llvm-ir/func.c +++ b/src/llvm/llvm-ir/func.c @@ -10,6 +10,7 @@ #include #include #include +#include LLVMLocalScope* new_local_scope(LLVMLocalScope* parent) { LLVMLocalScope* scope = malloc(sizeof(LLVMLocalScope)); @@ -127,8 +128,19 @@ BackendError impl_func_type(LLVMBackendCompileUnit* unit, DEBUG("implemented %ld parameter", llvm_params->len); + LLVMTypeRef llvm_return_type = LLVMVoidTypeInContext(unit->context); + if (func->kind == FunctionDeclarationKind) { + if (func->impl.declaration.return_value != NULL) { + err = get_type_impl(unit, scope, func->impl.declaration.return_value, &llvm_return_type); + } + } else { + if (func->impl.definition.return_value != NULL) { + err = get_type_impl(unit, scope, func->impl.definition.return_value, &llvm_return_type); + } + } + LLVMTypeRef llvm_fun_type = - LLVMFunctionType(LLVMVoidTypeInContext(unit->context), + LLVMFunctionType(llvm_return_type, (LLVMTypeRef*)llvm_params->data, llvm_params->len, 0); *llvm_fun = LLVMAddFunction(unit->module, func->name, llvm_fun_type); @@ -180,14 +192,39 @@ BackendError impl_func_def(LLVMBackendCompileUnit* unit, LLVMPositionBuilderAtEnd(builder, entry); LLVMBuildBr(builder, llvm_start_body_block); - // insert returning end block - LLVMBasicBlockRef end_block = - LLVMAppendBasicBlockInContext(unit->context, llvm_func, "func.end"); - LLVMPositionBuilderAtEnd(builder, end_block); - LLVMBuildRetVoid(builder); + LLVMValueRef terminator = LLVMGetBasicBlockTerminator(llvm_end_body_block); + if (terminator == NULL) { + // insert returning end block + LLVMBasicBlockRef end_block = + LLVMAppendBasicBlockInContext(unit->context, llvm_func, "func.end"); + LLVMPositionBuilderAtEnd(builder, end_block); - LLVMPositionBuilderAtEnd(builder, llvm_end_body_block); - LLVMBuildBr(builder, end_block); + LLVMValueRef llvm_return = NULL; + if (func->kind == FunctionDeclarationKind) { + if (func->impl.declaration.return_value != NULL) { + err = get_type_default_value(unit, global_scope, func->impl.declaration.return_value, &llvm_return); + if (err.kind != Success) { + return err; + } + LLVMBuildRet(builder, llvm_return); + }else { + LLVMBuildRetVoid(builder); + } + } else { + if (func->impl.definition.return_value != NULL) { + err = get_type_default_value(unit, global_scope, func->impl.definition.return_value, &llvm_return); + if (err.kind != Success) { + return err; + } + LLVMBuildRet(builder, llvm_return); + } else { + LLVMBuildRetVoid(builder); + } + } + + LLVMPositionBuilderAtEnd(builder, llvm_end_body_block); + LLVMBuildBr(builder, end_block); + } LLVMDisposeBuilder(builder); } @@ -247,3 +284,83 @@ BackendError impl_functions(LLVMBackendCompileUnit* unit, return err; } + +gboolean is_parameter_out(Parameter *param) { + gboolean is_out = FALSE; + + if (param->kind == ParameterDeclarationKind) { + is_out = param->impl.declaration.qualifier == Out || param->impl.declaration.qualifier == InOut; + } else { + is_out = param->impl.definiton.declaration.qualifier == Out || + param->impl.definiton.declaration.qualifier == InOut; + } + + return is_out; +} + +BackendError impl_func_call(LLVMBackendCompileUnit *unit, + LLVMBuilderRef builder, LLVMLocalScope *scope, + const FunctionCall *call, + LLVMValueRef* return_value) { + DEBUG("implementing function call..."); + BackendError err = SUCCESS; + + LLVMValueRef* arguments = NULL; + + // prevent memory allocation when number of bytes would be zero + // avoid going of assertion in memory cache + if (call->expressions->len > 0) { + arguments = mem_alloc(MemoryNamespaceLlvm, sizeof(LLVMValueRef) * call->expressions->len); + + for (size_t i = 0; i < call->expressions->len; i++) { + Expression *arg = g_array_index(call->expressions, Expression*, i); + + GArray* param_list; + if (call->function->kind == FunctionDeclarationKind) { + param_list = call->function->impl.definition.parameter; + } else { + param_list = call->function->impl.declaration.parameter; + } + + Parameter param = g_array_index(param_list, Parameter, i); + + LLVMValueRef llvm_arg = NULL; + err = impl_expr(unit, scope, builder, arg, is_parameter_out(¶m), 0, &llvm_arg); + + if (err.kind != Success) { + break; + } + + if (is_parameter_out(¶m)) { + if ((arg->kind == ExpressionKindParameter && !is_parameter_out(arg->impl.parameter)) || arg->kind != ExpressionKindParameter) { + LLVMValueRef index = LLVMConstInt(LLVMInt32Type(), 0, false); + LLVMTypeRef llvm_type = NULL; + get_type_impl(unit, scope->func_scope->global_scope, param.impl.declaration.type, &llvm_type); + llvm_arg = LLVMBuildGEP2(builder, llvm_type, llvm_arg, &index, 1, ""); + } + } + + arguments[i] = llvm_arg; + } + } + + if (err.kind == Success) { + LLVMValueRef llvm_func = LLVMGetNamedFunction(unit->module, call->function->name); + + if (llvm_func == NULL) { + return new_backend_impl_error(Implementation, NULL, "no declared function"); + } + + LLVMTypeRef llvm_func_type = g_hash_table_lookup(scope->func_scope->global_scope->functions, call->function->name); + + LLVMValueRef value = LLVMBuildCall2(builder, llvm_func_type, llvm_func, arguments, call->expressions->len, + ""); + + if (NULL != return_value) { + *return_value = value; + } + } + + return err; +} + diff --git a/src/llvm/llvm-ir/func.h b/src/llvm/llvm-ir/func.h index 349135a..a3f47ac 100644 --- a/src/llvm/llvm-ir/func.h +++ b/src/llvm/llvm-ir/func.h @@ -40,4 +40,9 @@ BackendError impl_functions(LLVMBackendCompileUnit* unit, LLVMGlobalScope* scope, GHashTable* variables); +BackendError impl_func_call(LLVMBackendCompileUnit *unit, + LLVMBuilderRef builder, LLVMLocalScope *scope, + const FunctionCall *call, + LLVMValueRef* return_value); + #endif // LLVM_BACKEND_FUNC_H_ diff --git a/src/llvm/llvm-ir/stmt.c b/src/llvm/llvm-ir/stmt.c index 2bae20d..b8cf81f 100644 --- a/src/llvm/llvm-ir/stmt.c +++ b/src/llvm/llvm-ir/stmt.c @@ -141,6 +141,8 @@ BackendError impl_basic_block(LLVMBackendCompileUnit *unit, LLVMBasicBlockRef end_previous_block = *llvm_start_block; + bool terminated = false; + for (size_t i = 0; i < block->statemnts->len; i++) { DEBUG("building block statement %d of %d", i, block->statemnts->len); Statement* stmt = g_array_index(block->statemnts, Statement*, i); @@ -152,12 +154,20 @@ BackendError impl_basic_block(LLVMBackendCompileUnit *unit, return err; } - if (llvm_next_end_block != NULL) { + terminated = LLVMGetBasicBlockTerminator(end_previous_block); + if (llvm_next_end_block != NULL && !terminated) { LLVMPositionBuilderAtEnd(builder, end_previous_block); LLVMBuildBr(builder, llvm_next_start_block); + LLVMPositionBuilderAtEnd(builder, llvm_next_end_block); end_previous_block = llvm_next_end_block; } + + if (terminated) { + end_previous_block = LLVMAppendBasicBlockInContext(unit->context, scope->func_scope->llvm_func, + "ret.after"); + LLVMPositionBuilderAtEnd(builder, end_previous_block); + } } *llvm_end_block = end_previous_block; @@ -213,80 +223,6 @@ BackendError impl_while(LLVMBackendCompileUnit *unit, return err; } -gboolean is_parameter_out(Parameter *param) { - gboolean is_out = FALSE; - - if (param->kind == ParameterDeclarationKind) { - is_out = param->impl.declaration.qualifier == Out || param->impl.declaration.qualifier == InOut; - } else { - is_out = param->impl.definiton.declaration.qualifier == Out || - param->impl.definiton.declaration.qualifier == InOut; - } - - return is_out; -} - -BackendError impl_func_call(LLVMBackendCompileUnit *unit, - LLVMBuilderRef builder, LLVMLocalScope *scope, - const FunctionCall *call) { - DEBUG("implementing function call..."); - BackendError err = SUCCESS; - - LLVMValueRef* arguments = NULL; - - // prevent memory allocation when number of bytes would be zero - // avoid going of assertion in memory cache - if (call->expressions->len > 0) { - arguments = mem_alloc(MemoryNamespaceLlvm, sizeof(LLVMValueRef) * call->expressions->len); - - for (size_t i = 0; i < call->expressions->len; i++) { - Expression *arg = g_array_index(call->expressions, Expression*, i); - - GArray* param_list; - if (call->function->kind == FunctionDeclarationKind) { - param_list = call->function->impl.definition.parameter; - } else { - param_list = call->function->impl.declaration.parameter; - } - - Parameter param = g_array_index(param_list, Parameter, i); - - LLVMValueRef llvm_arg = NULL; - err = impl_expr(unit, scope, builder, arg, is_parameter_out(¶m), 0, &llvm_arg); - - if (err.kind != Success) { - break; - } - - if (is_parameter_out(¶m)) { - if ((arg->kind == ExpressionKindParameter && !is_parameter_out(arg->impl.parameter)) || arg->kind != ExpressionKindParameter) { - LLVMValueRef index = LLVMConstInt(LLVMInt32Type(), 0, false); - LLVMTypeRef llvm_type = NULL; - get_type_impl(unit, scope->func_scope->global_scope, param.impl.declaration.type, &llvm_type); - llvm_arg = LLVMBuildGEP2(builder, llvm_type, llvm_arg, &index, 1, ""); - } - } - - arguments[i] = llvm_arg; - } - } - - if (err.kind == Success) { - LLVMValueRef llvm_func = LLVMGetNamedFunction(unit->module, call->function->name); - - if (llvm_func == NULL) { - return new_backend_impl_error(Implementation, NULL, "no declared function"); - } - - LLVMTypeRef llvm_func_type = g_hash_table_lookup(scope->func_scope->global_scope->functions, call->function->name); - - LLVMBuildCall2(builder, llvm_func_type, llvm_func, arguments, call->expressions->len, - ""); - } - - return err; -} - BackendError impl_cond_block(LLVMBackendCompileUnit *unit, LLVMBuilderRef builder, LLVMLocalScope *scope, Expression *cond, const Block *block, LLVMBasicBlockRef *cond_block, LLVMBasicBlockRef *start_body_block, LLVMBasicBlockRef *end_body_block, @@ -432,6 +368,23 @@ BackendError impl_decl(LLVMBackendCompileUnit *unit, return err; } +BackendError impl_return(LLVMBackendCompileUnit *unit, + LLVMBuilderRef builder, + LLVMLocalScope *scope, + Return *returnStmt) { + BackendError err = SUCCESS; + + LLVMValueRef expr = NULL; + err = impl_expr(unit, scope, builder, returnStmt->value, false, 0, &expr); + if (err.kind != Success) { + return err; + } + + LLVMBuildRet(builder, expr); + + return err; +} + BackendError impl_def(LLVMBackendCompileUnit *unit, LLVMBuilderRef builder, LLVMLocalScope *scope, @@ -502,7 +455,10 @@ BackendError impl_stmt(LLVMBackendCompileUnit *unit, LLVMBuilderRef builder, LLV err = impl_while(unit, builder, scope, llvm_start_block, llvm_end_block, &stmt->impl.whileLoop); break; case StatementKindFunctionCall: - err = impl_func_call(unit, builder, scope, &stmt->impl.call); + err = impl_func_call(unit, builder, scope, &stmt->impl.call, NULL); + break; + case StatementKindReturn: + err = impl_return(unit, builder, scope, &stmt->impl.returnStmt); break; default: err = new_backend_impl_error(Implementation, NULL, "Unexpected statement kind"); diff --git a/src/set/set.c b/src/set/set.c index a904bad..9c16aa2 100644 --- a/src/set/set.c +++ b/src/set/set.c @@ -1463,6 +1463,8 @@ IO_Qualifier getParameterQualifier(Parameter *parameter) { } } +int createfuncall(FunctionCall* funcall, AST_NODE_PTR currentNode); + Expression *createExpression(AST_NODE_PTR currentNode) { DEBUG("create Expression"); Expression *expression = mem_alloc(MemoryNamespaceSet, sizeof(Expression)); @@ -1595,6 +1597,14 @@ Expression *createExpression(AST_NODE_PTR currentNode) { return NULL; } break; + case AST_Call: + expression->kind = ExpressionKindFunctionCall; + expression->impl.call = mem_alloc(MemoryNamespaceSet, sizeof(FunctionCall)); + if (createfuncall(expression->impl.call, currentNode) == SEMANTIC_ERROR) { + return NULL; + } + expression->result = SET_function_get_return_type(expression->impl.call->function); + break; default: PANIC("Node is not an expression but from kind: %i", currentNode->kind); break; @@ -1901,14 +1911,13 @@ Parameter get_param_from_func(Function* func, size_t index) { } } -int createfuncall(Statement *parentStatement, AST_NODE_PTR currentNode) { +int createfuncall(FunctionCall* funcall, AST_NODE_PTR currentNode) { assert(currentNode != NULL); assert(currentNode->children->len == 2); AST_NODE_PTR argsListNode = AST_get_node(currentNode, 1); AST_NODE_PTR nameNode = AST_get_node(currentNode, 0); - FunctionCall funcall; Function *fun = NULL; if (nameNode->kind == AST_Ident) { int result = getFunction(nameNode->value, &fun); @@ -1934,8 +1943,8 @@ int createfuncall(Statement *parentStatement, AST_NODE_PTR currentNode) { } } - funcall.function = fun; - funcall.nodePtr = currentNode; + funcall->function = fun; + funcall->nodePtr = currentNode; size_t paramCount = 0; if (fun->kind == FunctionDeclarationKind) { @@ -1980,10 +1989,7 @@ int createfuncall(Statement *parentStatement, AST_NODE_PTR currentNode) { g_array_append_val(expressions, expr); } } - funcall.expressions = expressions; - - parentStatement->kind = StatementKindFunctionCall; - parentStatement->impl.call = funcall; + funcall->expressions = expressions; return SEMANTIC_OK; } @@ -2059,18 +2065,37 @@ int createStatement(Block *Parentblock, AST_NODE_PTR currentNode) { g_array_append_val(Parentblock->statemnts, statement); } break; - case AST_Call: + case AST_Call: { Statement *statement = mem_alloc(MemoryNamespaceSet, sizeof(Statement)); statement->nodePtr = currentNode; statement->kind = StatementKindFunctionCall; - int result = createfuncall(statement, currentNode); + int result = createfuncall(&statement->impl.call, currentNode); if (result == SEMANTIC_ERROR) { return SEMANTIC_ERROR; } g_array_append_val(Parentblock->statemnts, statement); break; + } + case AST_Return: { + Statement *statement = mem_alloc(MemoryNamespaceSet, sizeof(Statement)); + statement->nodePtr = currentNode; + statement->kind = StatementKindReturn; + + AST_NODE_PTR expr_node = AST_get_node(currentNode, 0); + statement->impl.returnStmt.value = createExpression(expr_node); + statement->impl.returnStmt.nodePtr = currentNode; + + if (statement->impl.returnStmt.value == NULL) { + return SEMANTIC_ERROR; + } + + // TODO: compare result and function return type + + g_array_append_val(Parentblock->statemnts, statement); + break; + } default: - PANIC("Node is not a statement"); + PANIC("Node is not a statement: %s", AST_node_to_string(currentNode)); break; } @@ -2129,8 +2154,9 @@ int createParam(GArray *Paramlist, AST_NODE_PTR currentNode) { int createFunDef(Function *Parentfunction, AST_NODE_PTR currentNode) { DEBUG("start fundef"); AST_NODE_PTR nameNode = AST_get_node(currentNode, 0); - AST_NODE_PTR paramlistlist = AST_get_node(currentNode, 1); - AST_NODE_PTR statementlist = AST_get_node(currentNode, 2); + AST_NODE_PTR return_value_node = AST_get_node(currentNode, 1); + AST_NODE_PTR paramlistlist = AST_get_node(currentNode, 2); + AST_NODE_PTR statementlist = AST_get_node(currentNode, 3); FunctionDefinition fundef; @@ -2138,6 +2164,12 @@ int createFunDef(Function *Parentfunction, AST_NODE_PTR currentNode) { fundef.name = nameNode->value; fundef.body = mem_alloc(MemoryNamespaceSet, sizeof(Block)); fundef.parameter = mem_new_g_array(MemoryNamespaceSet, sizeof(Parameter)); + fundef.return_value = NULL; + + if (set_get_type_impl(return_value_node, &fundef.return_value) == SEMANTIC_ERROR) { + print_diagnostic(&return_value_node->location, Error, "Unknown return value type"); + return SEMANTIC_ERROR; + } DEBUG("paramlistlist child count: %i", paramlistlist->children->len); for (size_t i = 0; i < paramlistlist->children->len; i++) { @@ -2237,7 +2269,49 @@ int getFunction(const char *name, Function **function) { return SEMANTIC_ERROR; } -int createFunDecl(Function *Parentfunction, AST_NODE_PTR currentNode) { +int createProcDef(Function *Parentfunction, AST_NODE_PTR currentNode) { + DEBUG("start fundef"); + AST_NODE_PTR nameNode = AST_get_node(currentNode, 0); + AST_NODE_PTR paramlistlist = AST_get_node(currentNode, 1); + AST_NODE_PTR statementlist = AST_get_node(currentNode, 2); + + FunctionDefinition fundef; + + fundef.nodePtr = currentNode; + fundef.name = nameNode->value; + fundef.body = mem_alloc(MemoryNamespaceSet, sizeof(Block)); + fundef.parameter = mem_new_g_array(MemoryNamespaceSet, sizeof(Parameter)); + fundef.return_value = NULL; + + DEBUG("paramlistlist child count: %i", paramlistlist->children->len); + for (size_t i = 0; i < paramlistlist->children->len; i++) { + + //all parameterlists + AST_NODE_PTR paramlist = AST_get_node(paramlistlist, i); + DEBUG("paramlist child count: %i", paramlist->children->len); + for (int j = ((int) paramlist->children->len) - 1; j >= 0; j--) { + + DEBUG("param child count: %i", AST_get_node(paramlist, j)->children->len); + + if (createParam(fundef.parameter, AST_get_node(paramlist, j))) { + return SEMANTIC_ERROR; + } + } + DEBUG("End of Paramlist"); + } + + if (fillBlock(fundef.body, statementlist)) { + return SEMANTIC_ERROR; + } + + Parentfunction->nodePtr = currentNode; + Parentfunction->kind = FunctionDefinitionKind; + Parentfunction->impl.definition = fundef; + Parentfunction->name = fundef.name; + return SEMANTIC_OK; +} + +int createProcDecl(Function *Parentfunction, AST_NODE_PTR currentNode) { DEBUG("start fundecl"); AST_NODE_PTR nameNode = AST_get_node(currentNode, 0); AST_NODE_PTR paramlistlist = AST_get_node(currentNode, 1); @@ -2247,6 +2321,46 @@ int createFunDecl(Function *Parentfunction, AST_NODE_PTR currentNode) { fundecl.nodePtr = currentNode; fundecl.name = nameNode->value; fundecl.parameter = mem_new_g_array(MemoryNamespaceSet, sizeof(Parameter)); + fundecl.return_value = NULL; + + for (size_t i = 0; i < paramlistlist->children->len; i++) { + + //all parameter lists + AST_NODE_PTR paramlist = AST_get_node(paramlistlist, i); + + for (int j = ((int) paramlist->children->len) - 1; j >= 0; j--) { + AST_NODE_PTR param = AST_get_node(paramlist, j); + if (createParam(fundecl.parameter, param)) { + return SEMANTIC_ERROR; + } + } + } + + Parentfunction->nodePtr = currentNode; + Parentfunction->kind = FunctionDeclarationKind; + Parentfunction->impl.declaration = fundecl; + Parentfunction->name = fundecl.name; + + return SEMANTIC_OK; +} + +int createFunDecl(Function *Parentfunction, AST_NODE_PTR currentNode) { + DEBUG("start fundecl"); + AST_NODE_PTR nameNode = AST_get_node(currentNode, 0); + AST_NODE_PTR return_value_node = AST_get_node(currentNode, 1); + AST_NODE_PTR paramlistlist = AST_get_node(currentNode, 2); + + FunctionDeclaration fundecl; + + fundecl.nodePtr = currentNode; + fundecl.name = nameNode->value; + fundecl.parameter = mem_new_g_array(MemoryNamespaceSet, sizeof(Parameter)); + fundecl.return_value = NULL; + + if (set_get_type_impl(return_value_node, &fundecl.return_value) == SEMANTIC_ERROR) { + print_diagnostic(&return_value_node->location, Error, "Unknown return value type"); + return SEMANTIC_ERROR; + } for (size_t i = 0; i < paramlistlist->children->len; i++) { @@ -2270,21 +2384,32 @@ int createFunDecl(Function *Parentfunction, AST_NODE_PTR currentNode) { } int createFunction(Function *function, AST_NODE_PTR currentNode) { - assert(currentNode->kind == AST_Fun); functionParameter = mem_new_g_hash_table(MemoryNamespaceSet, g_str_hash, g_str_equal); - if (currentNode->children->len == 2) { - int signal = createFunDecl(function, currentNode); - if (signal) { + switch (currentNode->kind) { + case AST_FunDecl: + if (createFunDecl(function, currentNode)) { + return SEMANTIC_ERROR; + } + break; + case AST_FunDef: + if (createFunDef(function, currentNode)) { + return SEMANTIC_ERROR; + } + break; + case AST_ProcDecl: + if (createProcDecl(function, currentNode)) { + return SEMANTIC_ERROR; + } + break; + case AST_ProcDef: + if (createProcDef(function, currentNode)) { + return SEMANTIC_ERROR; + } + break; + default: + ERROR("invalid AST node type: %s", AST_node_to_string(currentNode)); return SEMANTIC_ERROR; - } - } else if (currentNode->children->len == 3) { - int signal = createFunDef(function, currentNode); - if (signal) { - return SEMANTIC_ERROR; - } - } else { - PANIC("function should have 2 or 3 children"); } mem_free(functionParameter); @@ -2400,7 +2525,7 @@ int createBox(GHashTable *boxes, AST_NODE_PTR currentNode) { return SEMANTIC_ERROR; } break; - case AST_Fun: { + case AST_FunDef: { int result = createBoxFunction(boxName, boxType, AST_get_node(boxMemberList, i)); if (result == SEMANTIC_ERROR) { return SEMANTIC_ERROR; @@ -2523,7 +2648,10 @@ Module *create_set(AST_NODE_PTR currentNode) { DEBUG("created Box successfully"); break; } - case AST_Fun: { + case AST_FunDef: + case AST_FunDecl: + case AST_ProcDef: + case AST_ProcDecl: { DEBUG("start function"); Function *function = mem_alloc(MemoryNamespaceSet, sizeof(Function)); diff --git a/src/set/types.c b/src/set/types.c new file mode 100644 index 0000000..8b07cec --- /dev/null +++ b/src/set/types.c @@ -0,0 +1,28 @@ +// +// Created by servostar on 8/4/24. +// + +#include +#include +#include + +Type* SET_function_get_return_type(Function* function) { + assert(NULL != function); + + const Type* return_type = NULL; + + switch (function->kind) { + case FunctionDeclarationKind: + return_type = function->impl.declaration.return_value; + break; + case FunctionDefinitionKind: + return_type = function->impl.definition.return_value; + break; + default: + PANIC("invalid function kind: %d", function->kind); + } + + if (NULL == return_type) { + ERROR("Function return type is nullptr"); + } +} diff --git a/src/set/types.h b/src/set/types.h index 501b053..d4f2169 100644 --- a/src/set/types.h +++ b/src/set/types.h @@ -213,6 +213,7 @@ typedef struct FunctionDefinition_t { // hashtable of parameters // associates a parameters name (const char*) with its parameter declaration (ParameterDeclaration) GArray* parameter; // Parameter + Type* return_value; AST_NODE_PTR nodePtr; // body of function Block *body; @@ -225,6 +226,7 @@ typedef struct FunctionDeclaration_t { // associates a parameters name (const char*) with its parameter declaration (ParameterDeclaration) GArray* parameter; // Parameter AST_NODE_PTR nodePtr; + Type* return_value; const char* name; } FunctionDeclaration; @@ -439,8 +441,11 @@ typedef enum ExpressionKind_t { ExpressionKindParameter, ExpressionKindDereference, ExpressionKindAddressOf, + ExpressionKindFunctionCall, } ExpressionKind; +typedef struct FunctionCall_t FunctionCall; + typedef struct Expression_t { ExpressionKind kind; // type of resulting data @@ -454,6 +459,7 @@ typedef struct Expression_t { Parameter* parameter; Dereference dereference; AddressOf addressOf; + FunctionCall* call; } impl; AST_NODE_PTR nodePtr; } Expression; @@ -554,6 +560,11 @@ typedef struct Assignment_t { AST_NODE_PTR nodePtr; } Assignment; +typedef struct Return_t { + Expression* value; + AST_NODE_PTR nodePtr; +} Return; + typedef enum StatementKind_t { StatementKindFunctionCall, StatementKindFunctionBoxCall, @@ -561,7 +572,8 @@ typedef enum StatementKind_t { StatementKindBranch, StatementKindAssignment, StatementKindDeclaration, - StatementKindDefinition + StatementKindDefinition, + StatementKindReturn } StatementKind; typedef struct Statement_t { @@ -573,6 +585,7 @@ typedef struct Statement_t { Branch branch; Assignment assignment; Variable *variable; + Return returnStmt; } impl; AST_NODE_PTR nodePtr; } Statement; @@ -591,6 +604,11 @@ typedef struct Module_t { GArray* includes; } Module; +// .------------------------------------------------. +// | Utility | +// '------------------------------------------------' + +Type* SET_function_get_return_type(Function* function); // .------------------------------------------------. // | Cleanup Code | diff --git a/src/yacc/parser.y b/src/yacc/parser.y index d1468fd..2b43040 100644 --- a/src/yacc/parser.y +++ b/src/yacc/parser.y @@ -58,6 +58,8 @@ %type programbody %type fundef %type fundecl +%type procdecl +%type procdef %type box %type typedef %type exprlist @@ -72,7 +74,7 @@ %type reinterpretcast %type program %type storage_expr -%type return +%type returnstmt %token KeyInt @@ -151,6 +153,8 @@ programbody: moduleimport {$$ = $1;} | moduleinclude {$$ = $1;} | fundef{$$ = $1;} | fundecl{$$ = $1;} + | procdecl{$$ = $1;} + | procdef{$$ = $1;} | box{$$ = $1;} | definition{$$ = $1;} | decl{$$ = $1;} @@ -170,6 +174,7 @@ expr: ValFloat {$$ = AST_new_node(new_loc(), AST_Float, $1);} | typecast{$$ = $1;} | reinterpretcast{$$ = $1;} | '(' expr ')' {$$=$2;} + | funcall {$$=$1;} | KeyRef Ident {AST_NODE_PTR addrof = AST_new_node(new_loc(), AST_AddressOf, NULL); AST_push_node(addrof, AST_new_node(new_loc(), AST_Ident, $2)); $$ = addrof;} @@ -194,35 +199,39 @@ argumentlist: argumentlist '(' exprlist ')' {AST_push_node($1, $3); $$ = list;}; -// TODO: add ast node for definition and declaration -fundef: KeyFun Ident paramlist '{' statementlist'}' {AST_NODE_PTR fun = AST_new_node(new_loc(), AST_Fun, NULL); +procdef: KeyFun Ident paramlist '{' statementlist'}' {AST_NODE_PTR fun = AST_new_node(new_loc(), AST_ProcDef, NULL); AST_NODE_PTR ident = AST_new_node(new_loc(), AST_Ident, $2); AST_push_node(fun, ident); AST_push_node(fun, $3); AST_push_node(fun, $5); $$ = fun; DEBUG("Function");} - | KeyFun Ident paramlist ':' type '{' statementlist'}' {AST_NODE_PTR fun = AST_new_node(new_loc(), AST_Fun, NULL); - AST_NODE_PTR ident = AST_new_node(new_loc(), AST_Ident, $2); - AST_push_node(fun, ident); - AST_push_node(fun, $3); - AST_push_node(fun, $5); - AST_push_node(fun, $7); - $$ = fun; - DEBUG("Function");}; -fundecl: KeyFun Ident paramlist {AST_NODE_PTR fun = AST_new_node(new_loc(), AST_Fun, NULL); +procdecl: KeyFun Ident paramlist {AST_NODE_PTR fun = AST_new_node(new_loc(), AST_ProcDecl, NULL); AST_NODE_PTR ident = AST_new_node(new_loc(), AST_Ident, $2); AST_push_node(fun, ident); AST_push_node(fun, $3); $$ = fun; - DEBUG("Function");} - | KeyFun Ident paramlist ':' type {AST_NODE_PTR fun = AST_new_node(new_loc(), AST_Fun, NULL); - AST_NODE_PTR ident = AST_new_node(new_loc(), AST_Ident, $2); - AST_push_node(fun, ident); - AST_push_node(fun, $3); - $$ = fun; - DEBUG("Function");}; + DEBUG("Function");}; + +fundef: KeyFun type ':' Ident paramlist '{' statementlist'}' {AST_NODE_PTR fun = AST_new_node(new_loc(), AST_FunDef, NULL); + AST_NODE_PTR ident = AST_new_node(new_loc(), AST_Ident, $4); + AST_push_node(fun, ident); + AST_push_node(fun, $2); + AST_push_node(fun, $5); + AST_push_node(fun, $7); + $$ = fun; + DEBUG("Function");} + +fundecl: KeyFun type ':' Ident paramlist {AST_NODE_PTR fun = AST_new_node(new_loc(), AST_FunDecl, NULL); + AST_NODE_PTR ident = AST_new_node(new_loc(), AST_Ident, $4); + AST_push_node(fun, ident); + AST_push_node(fun, $2); + AST_push_node(fun, $5); + $$ = fun; + DEBUG("Function");}; + + paramlist: paramlist '(' params ')' {AST_push_node($1, $3); $$ = $1;} @@ -365,12 +374,13 @@ statement: assign {$$ = $1;} | definition {$$ = $1;} | while {$$ = $1;} | branchfull {$$ = $1;} - | return {$$ = $1;} + | returnstmt {$$ = $1;} | funcall {$$ = $1;} | boxcall{$$ = $1;}; -return: KeyReturn expr { AST_NODE_PTR return_stmt = AST_new_node(new_loc(), AST_Return, NULL); - AST_push_node(return_stmt, $2); }; +returnstmt: KeyReturn expr { AST_NODE_PTR return_stmt = AST_new_node(new_loc(), AST_Return, NULL); + AST_push_node(return_stmt, $2); + $$ = return_stmt; }; branchif: KeyIf expr '{' statementlist '}' { AST_NODE_PTR branch = AST_new_node(new_loc(), AST_If, NULL); AST_push_node(branch, $2); diff --git a/tests/stdlib/src/matrix.gsc b/tests/stdlib/src/matrix.gsc index 234629f..21de436 100644 --- a/tests/stdlib/src/matrix.gsc +++ b/tests/stdlib/src/matrix.gsc @@ -1,7 +1,7 @@ import "std" -fun ulog10(in u32: num, out u32: log) +fun u32:ulog10(in u32: num) { u32: base = 1 as u32 u32: count = 0 as u32 @@ -16,13 +16,12 @@ fun ulog10(in u32: num, out u32: log) count = 1 as u32 } - log = count + ret count } fun u32ToCstr(in u32: number)(out cstr: result, out u32: len) { - u32: bytes = 0 as u32 - ulog10(number, bytes) + u32: bytes = ulog10(number) cstr: buf = 0 as cstr heapAlloc(bytes)(buf as ref u8) @@ -92,10 +91,8 @@ fun test_matrix() heapFree(matrix as ref u8) } -fun main(): u32 +fun i32:main() { test_matrix() - exit_code = 0 - ret 0 }