Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ set(ir_SOURCES
LocalGraph.cpp
LocalStructuralDominance.cpp
ReFinalize.cpp
return-utils.cpp
stack-utils.cpp
table-utils.cpp
type-updating.cpp
Expand Down
99 changes: 99 additions & 0 deletions src/ir/return-utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
/*
* Copyright 2024 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ir/return-utils.h"
#include "ir/module-utils.h"
#include "wasm-builder.h"
#include "wasm-traversal.h"
#include "wasm.h"

namespace wasm::ReturnUtils {

namespace {

struct ReturnValueRemover : public PostWalker<ReturnValueRemover> {
void visitReturn(Return* curr) {
auto* value = curr->value;
assert(value);
curr->value = nullptr;
Builder builder(*getModule());
replaceCurrent(builder.makeSequence(builder.makeDrop(value), curr));
}

void visitCall(Call* curr) { handleReturnCall(curr); }
void visitCallIndirect(CallIndirect* curr) { handleReturnCall(curr); }
void visitCallRef(CallRef* curr) { handleReturnCall(curr); }

template<typename T> void handleReturnCall(T* curr) {
if (curr->isReturn) {
Fatal() << "Cannot remove return_calls in ReturnValueRemover";
}
}

void visitFunction(Function* curr) {
if (curr->body->type.isConcrete()) {
curr->body = Builder(*getModule()).makeDrop(curr->body);
}
}
};

} // anonymous namespace

void removeReturns(Function* func, Module& wasm) {
ReturnValueRemover().walkFunctionInModule(func, &wasm);
}

std::unordered_map<Function*, bool> findReturnCallers(Module& wasm) {
ModuleUtils::ParallelFunctionAnalysis<bool> analysis(
wasm, [&](Function* func, bool& hasReturnCall) {
if (func->imported()) {
return;
}

struct Finder : PostWalker<Finder> {
bool hasReturnCall = false;

void visitCall(Call* curr) {
if (curr->isReturn) {
hasReturnCall = true;
}
}
void visitCallIndirect(CallIndirect* curr) {
if (curr->isReturn) {
hasReturnCall = true;
}
}
void visitCallRef(CallRef* curr) {
if (curr->isReturn) {
hasReturnCall = true;
}
}
} finder;

finder.walk(func->body);
hasReturnCall = finder.hasReturnCall;
});

// Convert to an unordered map for fast lookups. TODO: Avoid a copy here.
std::unordered_map<Function*, bool> ret;
ret.reserve(analysis.map.size());
for (auto& [k, v] : analysis.map) {
ret[k] = v;
}
return ret;
}

} // namespace wasm::ReturnUtils
39 changes: 39 additions & 0 deletions src/ir/return-utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright 2024 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef wasm_ir_return_h
#define wasm_ir_return_h

#include "wasm.h"

namespace wasm::ReturnUtils {

// Removes values from both explicit returns and implicit ones (values that flow
// from the body). This is useful after changing a function's type to no longer
// return anything.
//
// This does *not* handle return calls, and will error on them. Removing a
// return call may change the semantics of the program, so we do not do it
// automatically here.
void removeReturns(Function* func, Module& wasm);

// Return a map of every function to whether it does a return call.
using ReturnCallersMap = std::unordered_map<Function*, bool>;
ReturnCallersMap findReturnCallers(Module& wasm);

} // namespace wasm::ReturnUtils

#endif // wasm_ir_return_h
19 changes: 2 additions & 17 deletions src/passes/DeadArgumentElimination.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include "ir/find_all.h"
#include "ir/lubs.h"
#include "ir/module-utils.h"
#include "ir/return-utils.h"
#include "ir/type-updating.h"
#include "ir/utils.h"
#include "param-utils.h"
Expand Down Expand Up @@ -358,23 +359,7 @@ struct DAE : public Pass {
}
}
// Remove any return values.
struct ReturnUpdater : public PostWalker<ReturnUpdater> {
Module* module;
ReturnUpdater(Function* func, Module* module) : module(module) {
walk(func->body);
}
void visitReturn(Return* curr) {
auto* value = curr->value;
assert(value);
curr->value = nullptr;
Builder builder(*module);
replaceCurrent(builder.makeSequence(builder.makeDrop(value), curr));
}
} returnUpdater(func, module);
// Remove any value flowing out.
if (func->body->type.isConcrete()) {
func->body = Builder(*module).makeDrop(func->body);
}
ReturnUtils::removeReturns(func, *module);
}

// Given a function and all the calls to it, see if we can refine the type of
Expand Down
105 changes: 86 additions & 19 deletions src/passes/Monomorphize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@
#include "ir/manipulation.h"
#include "ir/module-utils.h"
#include "ir/names.h"
#include "ir/return-utils.h"
#include "ir/type-updating.h"
#include "ir/utils.h"
#include "pass.h"
Expand All @@ -103,6 +104,36 @@ namespace wasm {

namespace {

// Core information about a call: the call itself, and if it is dropped, the
// drop.
struct CallInfo {
Call* call;
// Store a reference to the drop's pointer so that we can replace it, as when
// we optimize a dropped call we need to replace (drop (call)) with (call).
// Or, if the call is not dropped, this is nullptr.
Expression** drop;
};

// Finds the calls and whether each one of them is dropped.
struct CallFinder : public PostWalker<CallFinder> {
std::vector<CallInfo> infos;

void visitCall(Call* curr) {
// Add the call as not having a drop, and update the drop later if we are.
infos.push_back(CallInfo{curr, nullptr});
}

void visitDrop(Drop* curr) {
if (curr->value->is<Call>()) {
// The call we just added to |infos| is dropped.
assert(!infos.empty());
auto& back = infos.back();
assert(back.call == curr->value);
back.drop = getCurrentPointer();
}
}
};

// Relevant information about a callsite for purposes of monomorphization.
struct CallContext {
// The operands of the call, processed to leave the parts that make sense to
Expand Down Expand Up @@ -181,12 +212,12 @@ struct CallContext {
// remaining values by updating |newOperands| (for example, if all the values
// sent are constants, then |newOperands| will end up empty, as we have
// nothing left to send).
void buildFromCall(Call* call,
void buildFromCall(CallInfo& info,
std::vector<Expression*>& newOperands,
Module& wasm) {
Builder builder(wasm);

for (auto* operand : call->operands) {
for (auto* operand : info.call->operands) {
// Process the operand. This is a copy operation, as we are trying to move
// (copy) code from the callsite into the called function. When we find we
// can copy then we do so, and when we cannot that value remains as a
Expand All @@ -212,8 +243,7 @@ struct CallContext {
}));
}

// TODO: handle drop
dropped = false;
dropped = !!info.drop;
}

// Checks whether an expression can be moved into the context.
Expand Down Expand Up @@ -299,6 +329,11 @@ struct Monomorphize : public Pass {
void run(Module* module) override {
// TODO: parallelize, see comments below

// Find all the return-calling functions. We cannot remove their returns
// (because turning a return call into a normal call may break the program
// by using more stack).
auto returnCallersMap = ReturnUtils::findReturnCallers(*module);

// Note the list of all functions. We'll be adding more, and do not want to
// operate on those.
std::vector<Name> funcNames;
Expand All @@ -309,26 +344,38 @@ struct Monomorphize : public Pass {
// to call the monomorphized targets.
for (auto name : funcNames) {
auto* func = module->getFunction(name);
for (auto* call : FindAll<Call>(func->body).list) {
if (call->type == Type::unreachable) {

CallFinder callFinder;
callFinder.walk(func->body);
for (auto& info : callFinder.infos) {
if (info.call->type == Type::unreachable) {
// Ignore unreachable code.
// TODO: return_call?
continue;
}

if (call->target == name) {
if (info.call->target == name) {
// Avoid recursion, which adds some complexity (as we'd be modifying
// ourselves if we apply optimizations).
continue;
}

processCall(call, *module);
// If the target function does a return call, then as noted earlier we
// cannot remove its returns, so do not consider the drop as part of the
// context in such cases (as if we reverse-inlined the drop into the
// target then we'd be removing the returns).
if (returnCallersMap[module->getFunction(info.call->target)]) {
info.drop = nullptr;
}

processCall(info, *module);
}
}
}

// Try to optimize a call.
void processCall(Call* call, Module& wasm) {
void processCall(CallInfo& info, Module& wasm) {
auto* call = info.call;
auto target = call->target;
auto* func = wasm.getFunction(target);
if (func->imported()) {
Expand All @@ -342,19 +389,16 @@ struct Monomorphize : public Pass {
// if we use that context.
CallContext context;
std::vector<Expression*> newOperands;
context.buildFromCall(call, newOperands, wasm);
context.buildFromCall(info, newOperands, wasm);

// See if we've already evaluated this function + call context. If so, then
// we've memoized the result.
auto iter = funcContextMap.find({target, context});
if (iter != funcContextMap.end()) {
auto newTarget = iter->second;
if (newTarget != target) {
// When we computed this before we found a benefit to optimizing, and
// created a new monomorphized function to call. Use it by simply
// applying the new operands we computed, and adjusting the call target.
call->operands.set(newOperands);
call->target = newTarget;
// We saw benefit to optimizing this case. Apply that.
updateCall(info, newTarget, newOperands, wasm);
}
return;
}
Expand Down Expand Up @@ -419,8 +463,7 @@ struct Monomorphize : public Pass {
if (worthwhile) {
// We are using the monomorphized function, so update the call and add it
// to the module.
call->operands.set(newOperands);
call->target = monoFunc->name;
updateCall(info, monoFunc->name, newOperands, wasm);

wasm.addFunction(std::move(monoFunc));
}
Expand Down Expand Up @@ -453,8 +496,9 @@ struct Monomorphize : public Pass {
newParams.push_back(operand->type);
}
}
// TODO: support changes to results.
auto newResults = func->getResults();
// If we were dropped then we are pulling the drop into the monomorphized
// function, which means we return nothing.
auto newResults = context.dropped ? Type::none : func->getResults();
newFunc->type = Signature(Type(newParams), newResults);

// We must update local indexes: the new function has a potentially
Expand Down Expand Up @@ -549,9 +593,32 @@ struct Monomorphize : public Pass {
newFunc->body = builder.makeBlock(pre);
}

if (context.dropped) {
ReturnUtils::removeReturns(newFunc.get(), wasm);
}

return newFunc;
}

// Given a call and a new target it should be calling, apply that new target,
// including updating the operands and handling dropping.
void updateCall(const CallInfo& info,
Name newTarget,
const std::vector<Expression*>& newOperands,
Module& wasm) {
info.call->target = newTarget;
info.call->operands.set(newOperands);

if (info.drop) {
// Replace (drop (call)) with (call), that is, replace the drop with the
// (updated) call which now has type none. Note we should have handled
// unreachability before getting here.
assert(info.call->type != Type::unreachable);
info.call->type = Type::none;
*info.drop = info.call;
}
}

// Run some function-level optimizations on a function. Ideally we would run a
// minimal amount of optimizations here, but we do want to give the optimizer
// as much of a chance to work as possible, so for now do all of -O3 (in
Expand Down
Loading