Files
llvm-project/llvm/lib/Transforms/IPO/HostRPC.cpp
Nicolas Marie 3b1aae9380 Use GPUFirst with libc rpc
- add missing headers in rpc.h.def
- add an opcode in libc rpc to handle gpu first host functions calls
- Fixe pointer casting
- Fixe Generated function to account for AMD address space
- remove LibC duplicate FILE declarations
- remove global variable to allow asyncronize rpc call
2024-06-18 08:42:58 -07:00

962 lines
30 KiB
C++

//===- Transform/IPO/HostRPC.h - Code of automatic host rpc -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
//
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/IPO/HostRPC.h"
#include "llvm/ADT/EnumeratedArray.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/CodeGen/CommandFlags.h"
#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetOptions.h"
#include "llvm/Transforms/IPO/Attributor.h"
#include <cstdint>
#define DEBUG_TYPE "host-rpc"
using namespace llvm;
using ArgType = llvm::omp::OMPTgtHostRPCArgType;
static cl::opt<bool>
UseDummyHostModule("host-rpc-use-dummy-host-module", cl::init(false),
cl::Hidden,
cl::desc("Use dummy host module if there no host module "
"attached to the device module"));
namespace {
enum class HostRPCRuntimeFunction {
#define __OMPRTL_HOST_RPC(_ENUM) OMPRTL_##_ENUM
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_desc),
__OMPRTL_HOST_RPC(__kmpc_host_rpc_add_arg),
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_arg),
__OMPRTL_HOST_RPC(__kmpc_host_rpc_send_and_wait),
__OMPRTL_HOST_RPC(__kmpc_host_rpc_set_ret_val),
__OMPRTL_HOST_RPC(__kmpc_host_rpc_invoke_host_wrapper),
__OMPRTL_HOST_RPC(__last),
#undef __OMPRTL_HOST_RPC
};
#define __OMPRTL_HOST_RPC(_ENUM) \
auto OMPRTL_##_ENUM = HostRPCRuntimeFunction::OMPRTL_##_ENUM;
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_desc)
__OMPRTL_HOST_RPC(__kmpc_host_rpc_add_arg)
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_arg)
__OMPRTL_HOST_RPC(__kmpc_host_rpc_send_and_wait)
__OMPRTL_HOST_RPC(__kmpc_host_rpc_set_ret_val)
__OMPRTL_HOST_RPC(__kmpc_host_rpc_invoke_host_wrapper)
#undef __OMPRTL_HOST_RPC
// TODO: Remove those functions implemented in device runtime.
static constexpr const char *InternalPrefix[] = {
"__kmp", "llvm.", "nvm.",
"omp_", "vprintf", "malloc",
"free", "__keep_alive", "__llvm_omp_vprintf",
"rpc_"
};
bool isInternalFunction(Function &F) {
auto Name = F.getName();
for (auto *P : InternalPrefix)
if (Name.starts_with(P))
return true;
return false;
}
std::string typeToString(Type *T) {
if (T->is16bitFPTy())
return "f16";
if (T->isFloatTy())
return "f32";
if (T->isDoubleTy())
return "f64";
if (T->isPointerTy())
return "ptr";
if (T->isStructTy())
return std::string(T->getStructName());
if (T->isIntegerTy())
return "i" + std::to_string(T->getIntegerBitWidth());
LLVM_DEBUG(dbgs() << "[HostRPC] unknown type " << *T
<< " for typeToString.\n";);
llvm_unreachable("unknown type");
}
class HostRPC {
/// LLVM context instance
LLVMContext &Context;
/// Device module.
Module &M;
/// Host module
Module &HM;
/// Data layout of the device module.
DataLayout DL;
IRBuilder<> Builder;
/// External functions we are operating on.
SmallSetVector<Function *, 8> FunctionWorkList;
/// Attributor instance.
Attributor &A;
// Types
Type *Int8PtrTy;
Type *VoidTy;
Type *Int32Ty;
Type *Int64Ty;
StructType *ArgInfoTy;
// Values
Constant *NullPtr;
Constant *NullInt64;
struct CallSiteInfo {
CallInst *CI = nullptr;
SmallVector<Type *> Params;
};
struct HostRPCArgInfo {
Value *BasePtr = nullptr;
Constant *Type = nullptr;
Value *Size = nullptr;
};
///
SmallVector<Function *> HostEntryTable;
EnumeratedArray<Function *, HostRPCRuntimeFunction,
HostRPCRuntimeFunction::OMPRTL___last> RFIs;
SmallVector<std::pair<CallInst *, CallInst *>> CallInstMap;
Constant *getConstantInt64(uint64_t Val) {
return ConstantInt::get(Int64Ty, Val);
}
static std::string getWrapperFunctionName(Function *F, CallSiteInfo &CSI) {
std::string Name = "__kmpc_host_rpc_wrapper_" + std::string(F->getName());
if (!F->isVarArg())
return Name;
for (unsigned I = F->getFunctionType()->getNumParams();
I < CSI.Params.size(); ++I) {
Name.push_back('_');
Name.append(typeToString(CSI.Params[I]));
}
return Name;
}
void registerAAs();
Value *convertToInt64Ty(Value *V);
Value *convertFromInt64TyTo(Value *V, Type *TargetTy);
Constant *convertToInt64Ty(Constant *C);
Constant *convertFromInt64TyTo(Constant *C, Type *T);
// int device_wrapper(call_no, arg_info, ...) {
// void *desc = __kmpc_host_rpc_get_desc(call_no, num_args, arg_info);
// __kmpc_host_rpc_add_arg(desc, arg1, sizeof(arg1));
// __kmpc_host_rpc_add_arg(desc, arg2, sizeof(arg2));
// ...
// int r = (int)__kmpc_host_rpc_send_and_wait(desc);
// return r;
// }
Function *getDeviceWrapperFunction(StringRef WrapperName, Function *F,
CallSiteInfo &CSI);
// void host_wrapper(desc) {
// int arg1 = (int)__kmpc_host_rpc_get_arg(desc, 0);
// float arg2 = (float)__kmpc_host_rpc_get_arg(desc, 1);
// char *arg3 = (char *)__kmpc_host_rpc_get_arg(desc, 2);
// ...
// int r = actual_call(arg1, arg2, arg3, ...);
// __kmpc_host_rpc_set_ret_val(ptr(desc, (int64_t)r);
// }
Function *getHostWrapperFunction(StringRef WrapperName, Function *F,
CallSiteInfo &CSI);
bool rewriteWithHostRPC(Function *F);
void emitHostWrapperInvoker();
bool recollectInformation();
public:
HostRPC(Module &DeviceModule, Module &HostModule, Attributor &A)
: Context(DeviceModule.getContext()), M(DeviceModule), HM(HostModule),
DL(M.getDataLayout()), Builder(Context), A(A) {
assert(&M.getContext() == &HM.getContext() &&
"device and host modules have different context");
Int8PtrTy = PointerType::getUnqual(Context);
VoidTy = Type::getVoidTy(Context);
Int32Ty = Type::getInt32Ty(Context);
Int64Ty = Type::getInt64Ty(Context);
NullPtr = ConstantInt::getNullValue(Int8PtrTy);
NullInt64 = ConstantInt::getNullValue(Int64Ty);
#define __OMP_RTL(_ENUM, MOD, VARARG, RETTY, ...) \
{ \
SmallVector<Type *> Params{__VA_ARGS__}; \
Function *F = (MOD).getFunction(#_ENUM); \
if (!F) { \
FunctionType *FT = FunctionType::get(RETTY, Params, VARARG); \
F = Function::Create(FT, GlobalValue::LinkageTypes::ExternalLinkage, \
#_ENUM, (MOD)); \
} \
RFIs[OMPRTL_##_ENUM] = F; \
}
// devices functions:
// get information about the functions that we are calling
__OMP_RTL(__kmpc_host_rpc_get_desc, M, false, Int8PtrTy, Int32Ty, Int32Ty,
Int8PtrTy)
// get arguments information about one of the argument
__OMP_RTL(__kmpc_host_rpc_add_arg, M, false, VoidTy, Int8PtrTy, Int64Ty,
Int32Ty)
// send the function to the host the function
__OMP_RTL(__kmpc_host_rpc_send_and_wait, M, false, Int64Ty, Int8PtrTy)
// host functions:
// get arguments (mirror of add arg)
__OMP_RTL(__kmpc_host_rpc_get_arg, HM, false, Int64Ty, Int8PtrTy, Int32Ty)
// send the ruturn value
__OMP_RTL(__kmpc_host_rpc_set_ret_val, HM, false, VoidTy, Int8PtrTy,
Int64Ty)
// Invoke the function on the host
__OMP_RTL(__kmpc_host_rpc_invoke_host_wrapper, HM, false, VoidTy, Int32Ty,
Int8PtrTy)
#undef __OMP_RTL
ArgInfoTy = StructType::create({Int64Ty, Int64Ty, Int64Ty, Int8PtrTy},
"struct.arg_info_t");
}
bool run();
};
Value *HostRPC::convertToInt64Ty(Value *V) {
if (auto *C = dyn_cast<Constant>(V))
return convertToInt64Ty(C);
Type *T = V->getType();
if (T == Int64Ty)
return V;
if (T->isPointerTy())
return Builder.CreatePtrToInt(V, Int64Ty);
if (T->isIntegerTy())
return Builder.CreateIntCast(V, Int64Ty, /* isSigned */ true);
if (T->isFloatingPointTy()) {
V = Builder.CreateBitCast(
V, Type::getIntNTy(V->getContext(), T->getScalarSizeInBits()));
return Builder.CreateIntCast(V, Int64Ty, /* isSigned */ true);
}
llvm_unreachable("unknown cast to int64_t");
}
Value *HostRPC::convertFromInt64TyTo(Value *V, Type *T) {
if (auto *C = dyn_cast<Constant>(V))
return convertFromInt64TyTo(C, T);
if (T == Int64Ty)
return V;
if (T->isPointerTy())
return Builder.CreateIntToPtr(V, T);
if (T->isIntegerTy())
return Builder.CreateIntCast(V, T, /* isSigned */ true);
if (T->isFloatingPointTy()) {
V = Builder.CreateIntCast(
V, Type::getIntNTy(V->getContext(), T->getScalarSizeInBits()),
/* isSigned */ true);
return Builder.CreateBitCast(V, T);
}
LLVM_DEBUG(dbgs() << "[HostRPC] unknown type " << *T
<< " for typeFromint64_t.\n";);
llvm_unreachable("unknown cast from int64_t");
}
Constant *HostRPC::convertToInt64Ty(Constant *C) {
Type *T = C->getType();
if (T == Int64Ty)
return C;
if (T->isPointerTy())
return ConstantExpr::getPtrToInt(C, Int64Ty);
if (T->isIntegerTy()) {
return ConstantFoldIntegerCast(C, Int64Ty, true, DL);
}
if (T->isFloatingPointTy()) {
// cast to an int of the same size
C = ConstantExpr::getBitCast(C,
Type::getIntNTy(C->getContext(), T->getScalarSizeInBits()));
// set the int of size 64
return ConstantFoldIntegerCast(C, Int64Ty, true, DL);
}
llvm_unreachable("unknown cast to int64_t");
}
Constant *HostRPC::convertFromInt64TyTo(Constant *C, Type *T) {
assert(C->getType() == Int64Ty);
if (T == Int64Ty)
return C;
if (T->isPointerTy())
return ConstantExpr::getIntToPtr(C, T);
if (T->isIntegerTy()) {
return ConstantFoldIntegerCast(C, T, true, DL);
}
if (T->isFloatingPointTy()) {
// change size to T size
C = ConstantFoldIntegerCast(C,
Type::getIntNTy(C->getContext(), T->getScalarSizeInBits()), true, DL);
// from int to float
return ConstantExpr::getBitCast(C, T);
}
llvm_unreachable("unknown cast from int64_t");
}
void HostRPC::registerAAs() {
for (auto *F : FunctionWorkList)
for (User *U : F->users()) {
auto *CI = dyn_cast<CallInst>(U);
if (!CI)
continue;
for (unsigned I = 0; I < CI->arg_size(); ++I) {
Value *Operand = CI->getArgOperand(I);
if (!Operand->getType()->isPointerTy())
continue;
A.getOrCreateAAFor<AAUnderlyingObjects>(
IRPosition::callsite_argument(*CI, I),
/* QueryingAA */ nullptr, DepClassTy::NONE);
}
}
}
bool HostRPC::recollectInformation() {
FunctionWorkList.clear();
for (Function &F : M) {
// If the function is already defined, it definitely does not require RPC.
if (!F.isDeclaration())
continue;
// If it is an internal function, skip it as well.
if (isInternalFunction(F))
continue;
// If there is no use of the function, skip it.
if (F.use_empty())
continue;
LLVM_DEBUG({
dbgs() << "[HostRPC] RPCing function: " << F.getName() << "\n"
<< F << "\n";
});
FunctionWorkList.insert(&F);
}
return !FunctionWorkList.empty();
}
bool HostRPC::run() {
bool Changed = false;
LLVM_DEBUG(dbgs() << "[HostRPC] Running Pass\n");
if (!recollectInformation())
return Changed;
Changed = true;
// We add a couple of assumptions to those RPC functions such that AAs will
// not error out because of unknown implementation of those functions.
for (Function &F : M) {
if (!F.isDeclaration())
continue;
F.addFnAttr(Attribute::NoRecurse);
for (auto &Arg : F.args())
if (Arg.getType()->isPointerTy())
Arg.addAttr(Attribute::NoCapture);
if (!F.isVarArg())
continue;
for (User *U : F.users()) {
auto *CB = dyn_cast<CallBase>(U);
if (!CB)
continue;
for (unsigned I = F.getFunctionType()->getNumParams(); I < CB->arg_size();
++I) {
Value *Arg = CB->getArgOperand(I);
if (Arg->getType()->isPointerTy())
CB->addParamAttr(I, Attribute::NoCapture);
}
}
}
//LLVM_DEBUG(M.dump());
registerAAs();
ChangeStatus Status = A.run();
if (!recollectInformation())
return Status == ChangeStatus::CHANGED;
for (Function *F : FunctionWorkList)
Changed |= rewriteWithHostRPC(F);
if (!Changed)
return Changed;
// replace all call to the function to a call to the rpc wrapper that have replace it.
for (auto Itr = CallInstMap.rbegin(); Itr != CallInstMap.rend(); ++Itr) {
auto *CI = Itr->first;
auto *NewCI = Itr->second;
CI->replaceAllUsesWith(NewCI);
CI->eraseFromParent();
}
// erase all trace of the function in the Module
for (Function *F : FunctionWorkList)
if (F->user_empty())
F->eraseFromParent();
emitHostWrapperInvoker();
return Changed;
}
bool HostRPC::rewriteWithHostRPC(Function *F) {
SmallVector<CallInst *> WorkList;
for (User *U : F->users()) {
auto *CI = dyn_cast<CallInst>(U);
if (!CI)
continue;
WorkList.push_back(CI);
}
if (WorkList.empty())
return false;
for (CallInst *CI : WorkList) {
CallSiteInfo CSI;
CSI.CI = CI;
unsigned NumArgs = CI->arg_size();
for (unsigned I = 0; I < NumArgs; ++I)
CSI.Params.push_back(CI->getArgOperand(I)->getType());
std::string WrapperName = getWrapperFunctionName(F, CSI);
Function *DeviceWrapperFn = getDeviceWrapperFunction(WrapperName, F, CSI);
Function *HostWrapperFn = getHostWrapperFunction(WrapperName, F, CSI);
int32_t WrapperNumber = -1;
for (unsigned I = 0; I < HostEntryTable.size(); ++I) {
if (HostEntryTable[I] == HostWrapperFn) {
WrapperNumber = I;
break;
}
}
if (WrapperNumber == -1) {
WrapperNumber = HostEntryTable.size();
HostEntryTable.push_back(HostWrapperFn);
}
auto CheckIfIdentifierPtr = [this](const Value *V) {
auto *CI = dyn_cast<CallInst>(V);
if (!CI)
return false;
Function *Callee = CI->getCalledFunction();
if (this->FunctionWorkList.count(Callee))
return true;
return Callee->getName().starts_with("__kmpc_host_rpc_wrapper_");
};
auto CheckIfDynAlloc = [](Value *V) -> CallInst * {
auto *CI = dyn_cast<CallInst>(V);
if (!CI)
return nullptr;
Function *Callee = CI->getCalledFunction();
auto Name = Callee->getName();
if (Name == "malloc" || Name == "__kmpc_alloc_shared")
return CI;
return nullptr;
};
auto CheckIfStdIO = [](Value *V) -> GlobalVariable * {
auto *LI = dyn_cast<LoadInst>(V);
if (!LI)
return nullptr;
auto *GV = dyn_cast<GlobalVariable>(LI->getPointerOperand());
if (!GV)
return nullptr;
auto Name = GV->getName();
if (Name == "stdout" || Name == "stderr" || Name == "stdin")
return GV;
return nullptr;
};
auto CheckIfGlobalVariable = [](Value *V) {
if (auto *GV = dyn_cast<GlobalVariable>(V))
return GV;
if (auto *LI = dyn_cast<LoadInst>(V))
if (auto *GV = dyn_cast<GlobalVariable>(LI->getPointerOperand()))
return GV;
return static_cast<GlobalVariable *>(nullptr);
};
auto CheckIfNullPtr = [](Value *V) {
if (!V->getType()->isPointerTy())
return false;
return V == ConstantInt::getNullValue(V->getType());
};
auto HandleDirectUse = [&](Value *Ptr, HostRPCArgInfo &AI,
bool IsPointer = false) {
AI.BasePtr = Ptr;
AI.Type = getConstantInt64(IsPointer ? ArgType::OMP_HOST_RPC_ARG_PTR
: ArgType::OMP_HOST_RPC_ARG_SCALAR);
AI.Size = NullInt64;
};
SmallVector<SmallVector<HostRPCArgInfo>> ArgInfo;
bool IsConstantArgInfo = true;
for (unsigned I = 0; I < CI->arg_size(); ++I) {
ArgInfo.emplace_back();
auto &AII = ArgInfo.back();
Value *Operand = CI->getArgOperand(I);
LLVM_DEBUG({dbgs() << "[HostRPC] [argparse]: Argument: " << I << ": " << *Operand << "\n"; });
// Check if scalar type.
if (!Operand->getType()->isPointerTy()) {
AII.emplace_back();
HandleDirectUse(Operand, AII.back());
IsConstantArgInfo = IsConstantArgInfo && isa<Constant>(Operand);
LLVM_DEBUG({dbgs() << "[HostRPC] [argparse]: Constant: " << *Operand << "\n"; });
continue;
}
if (CheckIfNullPtr(Operand)){
LLVM_DEBUG({dbgs() << "[HostRPC] [argparse]: Null Ptr: " << *Operand << "\n"; });
continue;
}
auto Pred = [&](Value &Obj) {
if (CheckIfNullPtr(&Obj))
return true;
bool IsConstantArgument = false;
if (!F->isVarArg() &&
F->hasParamAttribute(I, Attribute::AttrKind::ReadOnly))
IsConstantArgument = true;
HostRPCArgInfo AI;
if (auto *IO = CheckIfStdIO(&Obj)) {
HandleDirectUse(IO, AI, /* IsPointer */ true);
} else if (CheckIfIdentifierPtr(&Obj)) {
IsConstantArgInfo = IsConstantArgInfo && isa<Constant>(Operand);
HandleDirectUse(Operand, AI, /* IsPointer */ true);
} else if (auto *GV = CheckIfGlobalVariable(&Obj)) {
AI.BasePtr = GV;
AI.Size = getConstantInt64(DL.getTypeStoreSize(GV->getValueType()));
AI.Type =
getConstantInt64(GV->isConstant() || IsConstantArgument
? ArgType::OMP_HOST_RPC_ARG_COPY_TO
: ArgType::OMP_HOST_RPC_ARG_COPY_TOFROM);
} else if (CheckIfDynAlloc(&Obj)) {
// We will handle this case at runtime so here we don't do anything.
LLVM_DEBUG({dbgs() << "[HostRPC] [argparse]: Dynamic Alloc: " << *Operand << "\n"; });
return true;
} else if (isa<AllocaInst>(&Obj)) {
llvm_unreachable("alloca instruction needs to be handled!");
} else {
LLVM_DEBUG({
dbgs() << "[HostRPC] warning: call site " << *CI << ", operand "
<< *Operand << ", underlying object " << Obj
<< " cannot be handled.\n";
});
return true;
}
AII.push_back(std::move(AI));
return true;
};
LLVM_DEBUG({
dbgs() << "[HostRPC] function rewrite:\n"
<< "Function: " << *F << "\n"
<< "Call site: " << *CI << "\n "
<< "Operand: " << *Operand << "\n";
});
// TODO replace with LLVM functions to not use Attributors.
assert(!IRPosition::callsite_argument(*CI, I)
.getAnchorScope()->hasFnAttribute(Attribute::OptimizeNone)
&& "[HostRPC]: Optimize None is not supported");
const llvm::AAUnderlyingObjects* AAUO =
A.getOrCreateAAFor<AAUnderlyingObjects>(
IRPosition::callsite_argument(*CI, I));
LLVM_DEBUG({dbgs() << "[HostRPC] AAUO:" << AAUO << "\n";});
if (!AAUO->forallUnderlyingObjects(Pred))
llvm_unreachable("internal error");
}
// Reset the insert point to the call site.
Builder.SetInsertPoint(CI);
Value *ArgInfoVal = nullptr;
if (!IsConstantArgInfo) {
ArgInfoVal = Builder.CreateAlloca(Int8PtrTy, getConstantInt64(NumArgs),
"arg_info");
for (unsigned I = 0; I < NumArgs; ++I) {
auto &AII = ArgInfo[I];
Value *Next = NullPtr;
for (auto &AI : AII) {
Value *AIV = Builder.CreateAlloca(ArgInfoTy);
Value *AIIArg =
GetElementPtrInst::Create(Int64Ty, AIV, {getConstantInt64(0)});
Builder.Insert(AIIArg);
Builder.CreateStore(convertToInt64Ty(AI.BasePtr), AIIArg);
Value *AIIType =
GetElementPtrInst::Create(Int64Ty, AIV, {getConstantInt64(1)});
Builder.Insert(AIIType);
Builder.CreateStore(AI.Type, AIIType);
Value *AIISize =
GetElementPtrInst::Create(Int64Ty, AIV, {getConstantInt64(2)});
Builder.Insert(AIISize);
Builder.CreateStore(AI.Size, AIISize);
Value *AIINext =
GetElementPtrInst::Create(Int8PtrTy, AIV, {getConstantInt64(3)});
Builder.Insert(AIINext);
Builder.CreateStore(Next, AIINext);
Next = AIV;
}
Value *AIIV = GetElementPtrInst::Create(Int8PtrTy, ArgInfoVal,
{getConstantInt64(I)});
Builder.Insert(AIIV);
Builder.CreateStore(Next, AIIV);
}
} else {
SmallVector<Constant *> ArgInfoInitVar;
for (auto &AII : ArgInfo) {
Constant *Last = NullPtr;
for (auto &AI : AII) {
auto *Arg = cast<Constant>(AI.BasePtr);
auto *CS =
ConstantStruct::get(ArgInfoTy, {convertToInt64Ty(Arg), AI.Type,
cast<Constant>(AI.Size), Last});
auto *GV = new GlobalVariable(
M, ArgInfoTy, /* isConstant */ true,
GlobalValue::LinkageTypes::InternalLinkage, CS, "",
nullptr, GlobalValue::ThreadLocalMode::NotThreadLocal, 0);
// force adress space 0 on AMD GPU
// insted of address space 1 for globals
Last = GV;
}
LLVM_DEBUG({
dbgs() << "[HostRPC] ArgInfoInitVar:" << *Last << "\n";
});
ArgInfoInitVar.push_back(Last);
}
Constant *ArgInfoInit = ConstantArray::get(
ArrayType::get(Int8PtrTy, NumArgs), ArgInfoInitVar);
ArgInfoVal = new GlobalVariable(
M, ArrayType::get(Int8PtrTy, NumArgs), /* isConstant */ true,
GlobalValue::LinkageTypes::InternalLinkage, ArgInfoInit, "arg_info",
nullptr, GlobalValue::ThreadLocalMode::NotThreadLocal, 0);
}
SmallVector<Value *> Args{ConstantInt::get(Int32Ty, WrapperNumber),
ArgInfoVal};
for (Value *Operand : CI->args())
Args.push_back(Operand);
CallInst *NewCall = Builder.CreateCall(DeviceWrapperFn, Args);
CallInstMap.emplace_back(CI, NewCall);
}
return true;
}
Function *HostRPC::getDeviceWrapperFunction(StringRef WrapperName, Function *F,
CallSiteInfo &CSI) {
Function *WrapperFn = M.getFunction(WrapperName);
if (WrapperFn)
return WrapperFn;
// return_type device_wrapper(int32_t call_no, void *arg_info, ...)
SmallVector<Type *> Params{Int32Ty, Int8PtrTy};
Params.append(CSI.Params);
Type *RetTy = F->getReturnType();
FunctionType *FT = FunctionType::get(RetTy, Params, /*isVarArg*/ false);
WrapperFn = Function::Create(FT, GlobalValue::LinkageTypes::InternalLinkage,
WrapperName, M);
// Emit the body of the device wrapper
BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", WrapperFn);
Builder.SetInsertPoint(EntryBB);
// skip call_no and arg_info.
constexpr const unsigned NumArgSkipped = 2;
Value *Desc = nullptr;
{
Function *Fn = RFIs[OMPRTL___kmpc_host_rpc_get_desc];
LLVM_DEBUG({dbgs() << "[HostRPC] Building: rpc get desc: " << *Fn << "\n"; });
for (unsigned i = 0; i < 3; ++i)
LLVM_DEBUG({dbgs() << "ParamI: " << *(Fn->getFunctionType()->getParamType(i)) << "\n"; });
Desc = Builder.CreateCall(
Fn,
{
WrapperFn->getArg(0),
ConstantInt::get(Int32Ty, WrapperFn->arg_size() - NumArgSkipped),
WrapperFn->getArg(1)
},
"desc"
);
}
{
Function *Fn = RFIs[OMPRTL___kmpc_host_rpc_add_arg];
LLVM_DEBUG({dbgs() << "[HostRPC] Building: rpc add arg\n"; });
for (unsigned I = NumArgSkipped; I < WrapperFn->arg_size(); ++I) {
Value *V = convertToInt64Ty(WrapperFn->getArg(I));
Builder.CreateCall(
Fn, {Desc, V, ConstantInt::get(Int32Ty, I - NumArgSkipped)});
}
}
LLVM_DEBUG({dbgs() << "[HostRPC] Building: rpc send and wait\n"; });
Value *RetVal =
Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_send_and_wait], {Desc});
if (RetTy->isVoidTy()) {
Builder.CreateRetVoid();
return WrapperFn;
}
if (RetTy != RetVal->getType())
RetVal = convertFromInt64TyTo(RetVal, RetTy);
Builder.CreateRet(RetVal);
LLVM_DEBUG({dbgs() << "[HostRPC] Device Wrapper Function:\n" << *WrapperFn; });
return WrapperFn;
}
Function *HostRPC::getHostWrapperFunction(StringRef WrapperName, Function *F,
CallSiteInfo &CSI) {
Function *WrapperFn = HM.getFunction(WrapperName);
if (WrapperFn)
return WrapperFn;
SmallVector<Type *> Params{Int8PtrTy};
FunctionType *FT = FunctionType::get(VoidTy, Params, /* isVarArg */ false);
WrapperFn = Function::Create(FT, GlobalValue::LinkageTypes::ExternalLinkage,
WrapperName, HM);
Value *Desc = WrapperFn->getArg(0);
// Emit the body of the host wrapper
BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", WrapperFn);
Builder.SetInsertPoint(EntryBB);
SmallVector<Value *> Args;
for (unsigned I = 0; I < CSI.CI->arg_size(); ++I) {
Value *V = Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_get_arg],
{Desc, ConstantInt::get(Int32Ty, I)});
Args.push_back(convertFromInt64TyTo(V, CSI.Params[I]));
}
// The host callee that will be called eventually by the host wrapper.
Function *HostCallee = HM.getFunction(F->getName());
if (!HostCallee)
HostCallee = Function::Create(F->getFunctionType(), F->getLinkage(),
F->getName(), HM);
Value *RetVal = Builder.CreateCall(HostCallee, Args);
if (!RetVal->getType()->isVoidTy()) {
RetVal = convertToInt64Ty(RetVal);
Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_set_ret_val],
{Desc, RetVal});
}
Builder.CreateRetVoid();
return WrapperFn;
}
void HostRPC::emitHostWrapperInvoker() {
IRBuilder<> Builder(Context);
unsigned NumEntries = HostEntryTable.size();
Function *F = RFIs[OMPRTL___kmpc_host_rpc_invoke_host_wrapper];
F->setDLLStorageClass(
GlobalValue::DLLStorageClassTypes::DLLExportStorageClass);
Value *CallNo = F->getArg(0);
Value *Desc = F->getArg(1);
SmallVector<BasicBlock *> SwitchBBs;
BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", F);
BasicBlock *ReturnBB = BasicBlock::Create(Context, "return", F);
// Emit code for the return bb.
Builder.SetInsertPoint(ReturnBB);
Builder.CreateRetVoid();
// Create BB for each host entry and emit function call.
for (unsigned I = 0; I < NumEntries; ++I) {
BasicBlock *BB = BasicBlock::Create(Context, "invoke.bb", F, ReturnBB);
SwitchBBs.push_back(BB);
Builder.SetInsertPoint(BB);
Builder.CreateCall(HostEntryTable[I], {Desc});
Builder.CreateBr(ReturnBB);
}
// Emit code for the entry BB.
Builder.SetInsertPoint(EntryBB);
SwitchInst *Switch = Builder.CreateSwitch(CallNo, ReturnBB, NumEntries);
for (unsigned I = 0; I < NumEntries; ++I)
Switch->addCase(ConstantInt::get(cast<IntegerType>(Int32Ty), I),
SwitchBBs[I]);
}
Module *getHostModule(Module &M) {
auto *MD = M.getNamedMetadata("llvm.hostrpc.hostmodule");
if (!MD || MD->getNumOperands() == 0)
return nullptr;
auto *Node = MD->getOperand(0);
assert(Node->getNumOperands() == 1 && "invliad named metadata");
auto *CAM = dyn_cast<ConstantAsMetadata>(Node->getOperand(0));
if (!CAM)
return nullptr;
auto *CI = cast<ConstantInt>(CAM->getValue());
Module *Mod = reinterpret_cast<Module *>(CI->getZExtValue());
M.eraseNamedMetadata(MD);
return Mod;
}
} // namespace
PreservedAnalyses HostRPCPass::run(Module &M, ModuleAnalysisManager &AM) {
std::unique_ptr<Module> DummyHostModule;
Module *HostModule = nullptr;
if (UseDummyHostModule) {
DummyHostModule =
std::make_unique<Module>("dummy-host-rpc.bc", M.getContext());
HostModule = DummyHostModule.get();
} else {
HostModule = getHostModule(M);
}
if (!HostModule)
return PreservedAnalyses::all();
bool PostLink = LTOPhase == ThinOrFullLTOPhase::FullLTOPostLink ||
LTOPhase == ThinOrFullLTOPhase::ThinLTOPreLink;
// The pass will not run if it is not invoked directly or not invoked at link
// time.
if (!UseDummyHostModule && !PostLink)
return PreservedAnalyses::all();
FunctionAnalysisManager &FAM =
AM.getResult<FunctionAnalysisManagerModuleProxy>(M).getManager();
AnalysisGetter AG(FAM);
CallGraphUpdater CGUpdater;
BumpPtrAllocator Allocator;
AttributorConfig AC(CGUpdater);
AC.DefaultInitializeLiveInternals = false;
AC.RewriteSignatures = false;
AC.PassName = DEBUG_TYPE;
AC.MaxFixpointIterations = 1024;
InformationCache InfoCache(M, AG, Allocator, /* CGSCC */ nullptr);
SetVector<Function *> Functions;
Attributor A(Functions, InfoCache, AC);
HostRPC RPC(M, *HostModule, A);
bool Changed = RPC.run();
LLVM_DEBUG({
if (Changed && UseDummyHostModule) {
M.dump();
HostModule->dump();
}
});
return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}