Compare commits

...

1 Commits

Author SHA1 Message Date
Shilei Tian
676d2638f3 Initial automatic host rpc 2022-10-24 22:08:03 -04:00
14 changed files with 1285 additions and 0 deletions

65
auto-host-rpc/device.ll Normal file
View File

@@ -0,0 +1,65 @@
; ModuleID = 'test-openmp-nvptx64-nvidia-cuda-sm_75.bc'
source_filename = "/home/shiltian/Documents/vscode/llvm-project/auto-host-rpc/test.c"
target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
target triple = "nvptx64-nvidia-cuda"
%struct.ddd = type { i32, i32, float }
@__omp_rtl_debug_kind = weak_odr hidden local_unnamed_addr constant i32 0
@__omp_rtl_assume_teams_oversubscription = weak_odr hidden local_unnamed_addr constant i32 0
@__omp_rtl_assume_threads_oversubscription = weak_odr hidden local_unnamed_addr constant i32 0
@__omp_rtl_assume_no_thread_state = weak_odr hidden local_unnamed_addr constant i32 0
@__omp_rtl_assume_no_nested_parallelism = weak_odr hidden local_unnamed_addr constant i32 0
@.str = private unnamed_addr constant [9 x i8] c"main.cpp\00", align 1
@.str1 = private unnamed_addr constant [2 x i8] c"r\00", align 1
@.str2 = private unnamed_addr constant [3 x i8] c"%d\00", align 1
@.str3 = private unnamed_addr constant [7 x i8] c"%f%d%s\00", align 1
@.str4 = private unnamed_addr constant [6 x i8] c"hello\00", align 1
; Function Attrs: convergent nounwind
; Device-side test function. It calls the external declarations @fopen,
; @fprintf (twice, with different variadic argument lists) and @fscanf, which
; have no definition in this module.
define hidden void @foo() local_unnamed_addr #0 {
entry:
; 12-byte buffer obtained from __kmpc_alloc_shared; addressed below through
; %struct.ddd.
%d = tail call align 16 dereferenceable_or_null(12) ptr @__kmpc_alloc_shared(i64 12) #4
; fopen(@.str, @.str1), i.e. fopen("main.cpp", "r").
%call = tail call noalias ptr @fopen(ptr noundef nonnull @.str, ptr noundef nonnull @.str1) #5
; fprintf with one variadic i32 argument.
%call1 = tail call i32 (ptr, ptr, ...) @fprintf(ptr noundef %call, ptr noundef nonnull @.str2, i32 noundef 6) #5
; fprintf with variadic (double, i32, ptr) arguments.
%call2 = tail call i32 (ptr, ptr, ...) @fprintf(ptr noundef %call, ptr noundef nonnull @.str3, double noundef 6.000000e+00, i32 noundef 1, ptr noundef nonnull @.str4) #5
; Address of field 1 of the %struct.ddd stored in %d.
%a = getelementptr inbounds %struct.ddd, ptr %d, i64 0, i32 1
; fscanf writes through a pointer derived from the shared allocation.
%call3 = tail call i32 (ptr, ptr, ...) @fscanf(ptr noundef %call, ptr noundef nonnull @.str2, ptr noundef nonnull %a) #5
; Release the 12-byte buffer allocated above.
tail call void @__kmpc_free_shared(ptr %d, i64 12)
ret void
}
; Function Attrs: nofree nosync nounwind allocsize(0)
declare ptr @__kmpc_alloc_shared(i64) local_unnamed_addr #1
; Function Attrs: convergent
declare noalias ptr @fopen(ptr noundef, ptr noundef) local_unnamed_addr #2
; Function Attrs: convergent
declare i32 @fprintf(ptr noundef, ptr noundef, ...) local_unnamed_addr #2
; Function Attrs: convergent
declare i32 @fscanf(ptr noundef, ptr noundef, ...) local_unnamed_addr #2
; Function Attrs: nosync nounwind
declare void @__kmpc_free_shared(ptr allocptr nocapture, i64) local_unnamed_addr #3
attributes #0 = { convergent nounwind "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="sm_75" "target-features"="+ptx77,+sm_75" }
attributes #1 = { nofree nosync nounwind allocsize(0) }
attributes #2 = { convergent "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="sm_75" "target-features"="+ptx77,+sm_75" }
attributes #3 = { nosync nounwind }
attributes #4 = { nounwind }
attributes #5 = { convergent nounwind }
!llvm.module.flags = !{!0, !1, !2, !3, !4, !5}
!llvm.ident = !{!6, !7}
!nvvm.annotations = !{}
!0 = !{i32 2, !"SDK Version", [2 x i32] [i32 11, i32 7]}
!1 = !{i32 1, !"wchar_size", i32 4}
!2 = !{i32 7, !"openmp", i32 50}
!3 = !{i32 7, !"openmp-device", i32 50}
!4 = !{i32 8, !"PIC Level", i32 2}
!5 = !{i32 7, !"frame-pointer", i32 2}
!6 = !{!"clang version 16.0.0"}
!7 = !{!"clang version 3.8.0 (tags/RELEASE_380/final)"}

560
auto-host-rpc/main.cpp Normal file
View File

@@ -0,0 +1,560 @@
#include "llvm/ADT/EnumeratedArray.h"
#include "llvm/ADT/Triple.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/CommandFlags.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Target/TargetOptions.h"
#include <string>
using namespace llvm;
namespace {
static LLVMContext Context;
static codegen::RegisterCodeGenFlags RCGF;
static constexpr const char *InternalPrefix[] = {"__kmp", "llvm.", "nvm.",
"omp_"};
/// Map an LLVM type to the short tag that is appended to wrapper names for
/// variadic callees: "i<N>" for integers, "ptr" for pointers, the struct name
/// for struct types, and "f16"/"f32"/"f64" for the handled floating-point
/// types. Any other type is treated as unreachable.
std::string typeToString(Type *T) {
  if (T->isIntegerTy()) {
    std::string Tag = "i";
    Tag += std::to_string(T->getIntegerBitWidth());
    return Tag;
  }
  if (T->isPointerTy())
    return "ptr";
  if (T->isStructTy())
    return T->getStructName().str();

  // Floating-point tags; the predicates are mutually exclusive.
  const char *FPTag = T->is16bitFPTy() ? "f16"
                      : T->isFloatTy() ? "f32"
                      : T->isDoubleTy() ? "f64"
                                        : nullptr;
  if (FPTag)
    return FPTag;

  llvm_unreachable("unknown type");
}
} // namespace
namespace llvm {
/// Identifiers of the host RPC runtime entry points. The enumerators are named
/// OMPRTL_<function name>; OMPRTL___last is the final enumerator and serves as
/// the upper bound of the array indexed by this enum.
enum class HostRPCRuntimeFunction {
#define __OMPRTL_HOST_RPC(_ENUM) OMPRTL_##_ENUM
  __OMPRTL_HOST_RPC(__kmpc_host_rpc_get_desc),
  __OMPRTL_HOST_RPC(__kmpc_host_rpc_add_arg),
  __OMPRTL_HOST_RPC(__kmpc_host_rpc_get_arg),
  __OMPRTL_HOST_RPC(__kmpc_host_rpc_send_and_wait),
  __OMPRTL_HOST_RPC(__kmpc_host_rpc_get_ret_val),
  __OMPRTL_HOST_RPC(__kmpc_host_rpc_set_ret_val),
  __OMPRTL_HOST_RPC(__last),
#undef __OMPRTL_HOST_RPC
};

// Unqualified aliases for the enumerators above. They are compile-time
// constants (not mutable globals) so they cannot be reassigned and do not
// export symbols from this translation unit.
#define __OMPRTL_HOST_RPC(_ENUM)                                               \
  constexpr auto OMPRTL_##_ENUM = HostRPCRuntimeFunction::OMPRTL_##_ENUM;
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_desc)
__OMPRTL_HOST_RPC(__kmpc_host_rpc_add_arg)
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_arg)
__OMPRTL_HOST_RPC(__kmpc_host_rpc_send_and_wait)
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_ret_val)
__OMPRTL_HOST_RPC(__kmpc_host_rpc_set_ret_val)
#undef __OMPRTL_HOST_RPC
// Classification of a single RPC argument. The numeric value is stored as the
// "type" field of the per-argument info entry built at each rewritten call
// site.
enum OMPHostRPCArgType {
// No need to copy.
OMP_HOST_RPC_ARG_SCALAR = 0,
// Pointer argument. NOTE(review): no copy direction is attached here and the
// visible code never selects this value - confirm the intended semantics.
OMP_HOST_RPC_ARG_PTR = 1,
// Copy to host.
OMP_HOST_RPC_ARG_PTR_COPY_TO = 2,
// Copy to device
OMP_HOST_RPC_ARG_PTR_COPY_FROM = 3,
// Default for pointer arguments whose underlying object is not a constant
// global.
// TODO: Do we have a tofrom pointer?
OMP_HOST_RPC_ARG_PTR_COPY_TOFROM = 4,
};
// struct HostRPCArgInfo {
// // OMPHostRPCArgType
// int64_t Type;
// int64_t Size;
// };
/// Rewrites calls to undefined, non-runtime functions in a device module into
/// calls to generated device-side wrappers, and emits the matching host-side
/// wrappers into a host module. Both modules must share one LLVMContext.
class AutoHostRPC {
  LLVMContext &Context;
  // Device module
  Module &DM;
  // Host module
  Module &HM;

  // Types
  Type *Int8PtrTy;
  Type *VoidTy;
  Type *Int32Ty;
  Type *Int64Ty;
  // { i64 type, i64 size } entry describing one call argument.
  StructType *ArgInfoTy;

  /// A call being rewritten together with the types of its actual arguments.
  struct CallSiteInfo {
    CallInst *CI = nullptr;
    SmallVector<Type *> Params;
  };

  /// Per-argument description stored into an ArgInfoTy entry.
  struct HostRPCArgInfo {
    // OMPHostRPCArgType
    Constant *Type;
    Value *Size;
  };

  // Host wrappers in the order they were first used; the index of a wrapper in
  // this table is the call number passed to its device wrapper.
  SmallVector<Function *> HostEntryTable;

  // Runtime function declarations, indexed by HostRPCRuntimeFunction.
  EnumeratedArray<Function *, HostRPCRuntimeFunction,
                  HostRPCRuntimeFunction::OMPRTL___last>
      RFIs;

  /// Build the wrapper name for a call to \p F. For variadic callees the types
  /// of the variadic arguments at this call site are appended, so call sites
  /// with different variadic argument lists get different wrappers.
  static std::string getWrapperFunctionName(Function *F, CallSiteInfo &CSI) {
    std::string Name = "__kmpc_host_rpc_wrapper_" + std::string(F->getName());
    if (!F->isVarArg())
      return Name;
    for (unsigned I = F->getFunctionType()->getNumParams();
         I < CSI.Params.size(); ++I) {
      Name.push_back('_');
      Name.append(typeToString(CSI.Params[I]));
    }
    return Name;
  }

  /// Return true if the name of \p F starts with one of InternalPrefix; such
  /// functions are never rewritten.
  static bool isInternalFunction(Function &F) {
    auto Name = F.getName();
    for (auto P : InternalPrefix)
      if (Name.startswith(P))
        return true;
    return false;
  }

  Value *convertToInt64Ty(IRBuilder<> &Builder, Value *V);

  Value *convertFromInt64TyTo(IRBuilder<> &Builder, Value *V, Type *TargetTy);

  // int device_wrapper(call_no, arg_info, ...) {
  //   void *desc = __kmpc_host_rpc_get_desc(call_no, num_args, arg_info);
  //   __kmpc_host_rpc_add_arg(desc, arg1, sizeof(arg1));
  //   __kmpc_host_rpc_add_arg(desc, arg2, sizeof(arg2));
  //   ...
  //   __kmpc_host_rpc_send_and_wait(desc);
  //   int r = (int)__kmpc_host_rpc_get_ret_val(desc);
  //   return r;
  // }
  Function *getDeviceWrapperFunction(StringRef WrapperName, Function *F,
                                     CallSiteInfo &CSI);

  // void host_wrapper(desc) {
  //   int arg1 = (int)__kmpc_host_rpc_get_arg(desc, 0);
  //   float arg2 = (float)__kmpc_host_rpc_get_arg(desc, 1);
  //   char *arg3 = (char *)__kmpc_host_rpc_get_arg(desc, 2);
  //   ...
  //   int r = actual_call(arg1, arg2, arg3, ...);
  //   __kmpc_host_rpc_set_ret_val(desc, (int64_t)r);
  // }
  Function *getHostWrapperFunction(StringRef WrapperName, Function *F,
                                   CallSiteInfo &CSI);

  bool rewriteWithHostRPC(Function *F);

public:
  AutoHostRPC(Module &DeviceModule, Module &HostModule)
      : Context(DeviceModule.getContext()), DM(DeviceModule), HM(HostModule) {
    assert(&DeviceModule.getContext() == &HostModule.getContext() &&
           "device and host modules have different context");

#define __OMP_TYPE(TYPE) TYPE = Type::get##TYPE(Context)
    __OMP_TYPE(Int8PtrTy);
    __OMP_TYPE(VoidTy);
    __OMP_TYPE(Int32Ty);
    __OMP_TYPE(Int64Ty);
#undef __OMP_TYPE

// Declare a runtime function in MOD unless the module already has one with
// that name. The declaration has no body, so it must use external linkage:
// an internal-linkage declaration is rejected by the IR verifier.
#define __OMP_RTL(_ENUM, MOD, VARARG, RETTY, ...)                              \
  {                                                                            \
    SmallVector<Type *> Params{__VA_ARGS__};                                   \
    FunctionType *FT = FunctionType::get(RETTY, Params, VARARG);               \
    Function *RTFn = (MOD).getFunction(#_ENUM);                                \
    if (!RTFn)                                                                 \
      RTFn = Function::Create(                                                 \
          FT, GlobalValue::LinkageTypes::ExternalLinkage, #_ENUM, (MOD));      \
    RFIs[OMPRTL_##_ENUM] = RTFn;                                               \
  }
    // Device-side runtime entry points.
    __OMP_RTL(__kmpc_host_rpc_get_desc, DM, false, Int8PtrTy, Int32Ty, Int32Ty,
              Int8PtrTy)
    __OMP_RTL(__kmpc_host_rpc_add_arg, DM, false, VoidTy, Int8PtrTy, Int64Ty,
              Int64Ty)
    __OMP_RTL(__kmpc_host_rpc_send_and_wait, DM, false, VoidTy, Int8PtrTy)
    __OMP_RTL(__kmpc_host_rpc_get_ret_val, DM, false, Int64Ty, Int8PtrTy)
    // Host-side runtime entry points.
    __OMP_RTL(__kmpc_host_rpc_get_arg, HM, false, Int64Ty, Int8PtrTy, Int32Ty)
    __OMP_RTL(__kmpc_host_rpc_set_ret_val, HM, false, VoidTy, Int8PtrTy,
              Int64Ty)
#undef __OMP_RTL

    ArgInfoTy = StructType::create({Int64Ty, Int64Ty}, "struct.arg_info_t");
  }

  /// Rewrite every eligible call in the device module. Returns true if the
  /// module was changed.
  bool run();
};
/// Pack \p V into an i64 so it can travel through the RPC argument/return
/// slots.
///  - i64 is returned unchanged,
///  - pointers are converted with ptrtoint,
///  - other integers are zero-extended (or truncated) to i64,
///  - floating-point values of at most 64 bits are transported as their raw
///    bit pattern (bitcast to the same-width integer, then zero-extended), so
///    convertFromInt64TyTo recovers the exact value. The previous
///    fptosi-based code lost fractional parts and, for float, applied fptosi
///    to an integer operand, which is not a valid cast.
Value *AutoHostRPC::convertToInt64Ty(IRBuilder<> &Builder, Value *V) {
  Type *T = V->getType();
  if (T == Int64Ty)
    return V;
  if (T->isPointerTy())
    return Builder.CreatePtrToInt(V, Int64Ty);
  if (T->isIntegerTy())
    return Builder.CreateIntCast(V, Int64Ty, /* isSigned */ false);
  if (T->isFloatingPointTy()) {
    unsigned Bits = T->getScalarSizeInBits();
    if (Bits > 64)
      llvm_unreachable("floating-point type does not fit into int64_t");
    Value *Raw = Builder.CreateBitCast(V, Builder.getIntNTy(Bits));
    return Builder.CreateZExtOrTrunc(Raw, Int64Ty);
  }
  llvm_unreachable("unknown cast to int64_t");
}

/// Inverse of convertToInt64Ty: unpack the i64 \p V into a value of type \p T.
/// Floating-point targets are rebuilt from the low bits of \p V by truncating
/// to the same-width integer and bitcasting.
Value *AutoHostRPC::convertFromInt64TyTo(IRBuilder<> &Builder, Value *V,
                                         Type *T) {
  if (T == Int64Ty)
    return V;
  if (T->isPointerTy())
    return Builder.CreateIntToPtr(V, T);
  if (T->isIntegerTy())
    return Builder.CreateIntCast(V, T, /* isSigned */ true);
  if (T->isFloatingPointTy()) {
    unsigned Bits = T->getScalarSizeInBits();
    if (Bits > 64)
      llvm_unreachable("floating-point type does not fit into int64_t");
    Value *Raw = Builder.CreateZExtOrTrunc(V, Builder.getIntNTy(Bits));
    return Builder.CreateBitCast(Raw, T);
  }
  llvm_unreachable("unknown cast from int64_t");
}
bool AutoHostRPC::run() {
bool Changed = false;
SmallVector<Function *> WorkList;
for (Function &F : DM) {
// If the function is already defined, it definitely does not require RPC.
if (!F.isDeclaration())
continue;
// If it is an internal function, skip it as well.
if (isInternalFunction(F))
continue;
// If there is no use of the function, skip it.
if (F.use_empty())
continue;
WorkList.push_back(&F);
}
if (WorkList.empty())
return Changed;
for (Function *F : WorkList)
Changed |= rewriteWithHostRPC(F);
return Changed;
}
/// Replace every direct call to \p F in the device module with a call to the
/// matching device wrapper, creating device and host wrappers on demand.
///
/// For each call site an array of {type, size} entries (one per argument) is
/// built: a constant global when all sizes are constants, otherwise a stack
/// array filled in before the call. The wrapper receives the call number, the
/// argument-info array and the original arguments.
///
/// \returns true if at least one call was rewritten.
bool AutoHostRPC::rewriteWithHostRPC(Function *F) {
  // Only rewrite uses where F is the callee of a call whose signature matches
  // F. Uses of F as a plain argument are left alone.
  SmallVector<CallInst *> WorkList;
  for (Use &U : F->uses()) {
    auto *CI = dyn_cast<CallInst>(U.getUser());
    if (!CI || !CI->isCallee(&U) || CI->getCalledFunction() != F)
      continue;
    WorkList.push_back(CI);
  }

  if (WorkList.empty())
    return false;

  const DataLayout &DL = DM.getDataLayout();

  // Name of the function directly called by V, or an empty name when V is not
  // a call or is an indirect call (getCalledFunction() returns null then).
  auto GetDirectCalleeName = [](const Value *V) -> StringRef {
    if (auto *Call = dyn_cast<CallInst>(V))
      if (Function *Callee = Call->getCalledFunction())
        return Callee->getName();
    return StringRef();
  };
  // Pointers returned by an already rewritten call are opaque host
  // identifiers and are passed as scalars.
  auto CheckIfIdentifierPtr = [&](const Value *V) {
    return GetDirectCalleeName(V).startswith("__kmpc_host_rpc_wrapper_");
  };
  auto CheckIfAlloca = [&](const Value *V) {
    StringRef Name = GetDirectCalleeName(V);
    return Name == "__kmpc_alloc_shared" || Name == "malloc";
  };

  for (CallInst *CI : WorkList) {
    CallSiteInfo CSI;
    CSI.CI = CI;
    unsigned NumArgs = CI->arg_size();
    for (unsigned I = 0; I < NumArgs; ++I)
      CSI.Params.push_back(CI->getArgOperand(I)->getType());

    std::string WrapperName = getWrapperFunctionName(F, CSI);
    Function *DeviceWrapperFn = getDeviceWrapperFunction(WrapperName, F, CSI);
    Function *HostWrapperFn = getHostWrapperFunction(WrapperName, F, CSI);

    // The call number is the index of the host wrapper in the entry table.
    int32_t WrapperNumber = -1;
    for (unsigned I = 0; I < HostEntryTable.size(); ++I) {
      if (HostEntryTable[I] == HostWrapperFn) {
        WrapperNumber = I;
        break;
      }
    }
    if (WrapperNumber == -1) {
      WrapperNumber = HostEntryTable.size();
      HostEntryTable.push_back(HostWrapperFn);
    }

    IRBuilder<> Builder(CI);

    SmallVector<HostRPCArgInfo> ArgInfos;
    bool IsConstantArgInfo = true;
    for (Value *Op : CI->args()) {
      if (!Op->getType()->isPointerTy()) {
        HostRPCArgInfo AI{
            ConstantInt::get(Int64Ty,
                             OMPHostRPCArgType::OMP_HOST_RPC_ARG_SCALAR),
            ConstantInt::getNullValue(Int64Ty)};
        ArgInfos.push_back(std::move(AI));
        continue;
      }

      Value *SizeVal = nullptr;
      OMPHostRPCArgType ArgType = OMP_HOST_RPC_ARG_PTR_COPY_TOFROM;
      SmallVector<const Value *> Objects;
      getUnderlyingObjects(Op, Objects);
      // TODO: Handle phi node
      if (Objects.size() != 1)
        llvm_unreachable("we can't handle phi node yet");
      auto *Obj = Objects.front();

      if (CheckIfIdentifierPtr(Obj)) {
        ArgType = OMP_HOST_RPC_ARG_SCALAR;
        SizeVal = ConstantInt::getNullValue(Int64Ty);
      } else if (CheckIfAlloca(Obj)) {
        // The first operand of the allocation call is the size in bytes.
        SizeVal = cast<CallInst>(Obj)->getOperand(0);
        // The size is stored into an i64 field; normalise narrower or wider
        // integer sizes. Constants are folded and stay constants.
        if (SizeVal->getType() != Int64Ty && SizeVal->getType()->isIntegerTy())
          SizeVal = Builder.CreateZExtOrTrunc(SizeVal, Int64Ty);
        if (!isa<Constant>(SizeVal))
          IsConstantArgInfo = false;
      } else if (auto *GV = dyn_cast<GlobalVariable>(Obj)) {
        SizeVal =
            ConstantInt::get(Int64Ty, DL.getTypeStoreSize(GV->getValueType()));
        // Constant globals only need to be copied to the host.
        if (GV->isConstant())
          ArgType = OMP_HOST_RPC_ARG_PTR_COPY_TO;
        // TODO: If the global variable is constant and has an initializer, we
        // can do some optimization.
      } else {
        // TODO: fix that when it occurs
        llvm_unreachable("cannot handle unknown type");
      }

      HostRPCArgInfo AI{ConstantInt::get(Int64Ty, ArgType), SizeVal};
      ArgInfos.push_back(std::move(AI));
    }

    Value *ArgInfo = nullptr;
    if (!IsConstantArgInfo) {
      // Allocate the array in the entry block so that a call site inside a
      // loop does not grow the stack on every iteration.
      BasicBlock &EntryBB = CI->getFunction()->getEntryBlock();
      IRBuilder<> AllocaBuilder(&EntryBB, EntryBB.getFirstInsertionPt());
      ArgInfo = AllocaBuilder.CreateAlloca(
          ArgInfoTy, ConstantInt::get(Int64Ty, NumArgs), "arg_info");

      for (unsigned I = 0; I < NumArgs; ++I) {
        // &arg_info[I]
        Value *AII = Builder.CreateConstInBoundsGEP1_64(ArgInfoTy, ArgInfo, I);
        // Field addresses need a struct GEP (indices {0, field}). A GEP with
        // a single index would step over whole array elements instead.
        Value *AIIType = Builder.CreateStructGEP(ArgInfoTy, AII, 0);
        Value *AIISize = Builder.CreateStructGEP(ArgInfoTy, AII, 1);
        Builder.CreateStore(ArgInfos[I].Type, AIIType);
        Builder.CreateStore(ArgInfos[I].Size, AIISize);
      }
    } else {
      SmallVector<Constant *> ArgInfoInitVar;
      for (auto &AI : ArgInfos) {
        auto *CS =
            ConstantStruct::get(ArgInfoTy, {AI.Type, cast<Constant>(AI.Size)});
        ArgInfoInitVar.push_back(CS);
      }
      Constant *ArgInfoInit = ConstantArray::get(
          ArrayType::get(ArgInfoTy, NumArgs), ArgInfoInitVar);
      ArgInfo = new GlobalVariable(
          DM, ArrayType::get(ArgInfoTy, NumArgs), /* isConstant */ true,
          GlobalValue::LinkageTypes::InternalLinkage, ArgInfoInit, "arg_info");
    }

    SmallVector<Value *> Args{ConstantInt::get(Int32Ty, WrapperNumber),
                              ArgInfo};
    for (Value *Op : CI->args())
      Args.push_back(Op);

    CallInst *NewCall = Builder.CreateCall(DeviceWrapperFn, Args);

    CI->replaceAllUsesWith(NewCall);
    CI->eraseFromParent();
  }

  // Other uses (for example the address being taken) may remain; erasing a
  // function that still has uses is invalid.
  if (F->use_empty())
    F->eraseFromParent();

  return true;
}
/// Return the device-side wrapper named \p WrapperName, creating it in the
/// device module on first use.
///
/// The wrapper has the signature
///   ret_ty wrapper(i32 call_no, ptr arg_info, <call-site argument types>...)
/// and its body obtains a descriptor, adds every forwarded argument as an i64,
/// sends the request and waits, then returns the converted return value (or
/// nothing when \p F returns void).
Function *AutoHostRPC::getDeviceWrapperFunction(StringRef WrapperName,
                                                Function *F,
                                                CallSiteInfo &CSI) {
  if (Function *Existing = DM.getFunction(WrapperName))
    return Existing;

  // call_no and arg_info precede the forwarded call arguments.
  constexpr unsigned NumLeadingParams = 2;

  SmallVector<Type *> ParamTys;
  ParamTys.push_back(Int32Ty);
  ParamTys.push_back(Int8PtrTy);
  ParamTys.append(CSI.Params);

  Type *ResultTy = F->getReturnType();
  Function *Wrapper = Function::Create(
      FunctionType::get(ResultTy, ParamTys, /*isVarArg*/ false),
      GlobalValue::WeakODRLinkage, WrapperName, DM);

  // Emit the body into a single entry block.
  IRBuilder<> Builder(BasicBlock::Create(Context, "entry", Wrapper));

  // desc = __kmpc_host_rpc_get_desc(call_no, num_args, arg_info)
  const unsigned NumParams = Wrapper->arg_size();
  Value *NumForwarded =
      ConstantInt::get(Int32Ty, NumParams - NumLeadingParams);
  Value *Desc =
      Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_get_desc],
                         {Wrapper->getArg(0), NumForwarded, Wrapper->getArg(1)},
                         "desc");

  // __kmpc_host_rpc_add_arg(desc, (i64)arg, 0) for each forwarded argument.
  Function *AddArgFn = RFIs[OMPRTL___kmpc_host_rpc_add_arg];
  for (unsigned Idx = NumLeadingParams; Idx != NumParams; ++Idx) {
    Value *Packed = convertToInt64Ty(Builder, Wrapper->getArg(Idx));
    Builder.CreateCall(AddArgFn,
                       {Desc, Packed, ConstantInt::getNullValue(Int64Ty)});
  }

  Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_send_and_wait], {Desc});

  if (!ResultTy->isVoidTy()) {
    Value *Result =
        Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_get_ret_val], {Desc});
    if (Result->getType() != ResultTy)
      Result = convertFromInt64TyTo(Builder, Result, ResultTy);
    Builder.CreateRet(Result);
  } else {
    Builder.CreateRetVoid();
  }

  return Wrapper;
}
/// Return the host-side wrapper named \p WrapperName, creating it in the host
/// module on first use.
///
/// The wrapper has the signature `void wrapper(ptr desc)`. Its body reads each
/// argument of the call site in \p CSI from the descriptor, converts it back
/// to its original type, calls the host declaration of \p F and, when \p F
/// returns a value, stores the result into the descriptor.
Function *AutoHostRPC::getHostWrapperFunction(StringRef WrapperName,
                                              Function *F, CallSiteInfo &CSI) {
  Function *WrapperFn = HM.getFunction(WrapperName);
  if (WrapperFn)
    return WrapperFn;

  SmallVector<Type *> Params{Int8PtrTy};
  FunctionType *FT = FunctionType::get(VoidTy, Params, /* isVarArg */ false);
  WrapperFn = Function::Create(FT, GlobalValue::LinkageTypes::ExternalLinkage,
                               WrapperName, HM);

  Value *Desc = WrapperFn->getArg(0);

  // Emit the body of the host wrapper
  IRBuilder<> Builder(Context);
  BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", WrapperFn);
  Builder.SetInsertPoint(EntryBB);

  SmallVector<Value *> Args;
  for (unsigned I = 0; I < CSI.CI->arg_size(); ++I) {
    Value *V = Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_get_arg],
                                  {Desc, ConstantInt::get(Int32Ty, I)});
    Args.push_back(convertFromInt64TyTo(Builder, V, CSI.Params[I]));
  }

  // The host callee that will be called eventually by the host wrapper.
  Function *HostCallee = HM.getFunction(F->getName());
  if (!HostCallee)
    HostCallee = Function::Create(F->getFunctionType(), F->getLinkage(),
                                  F->getName(), HM);

  Value *RetVal = Builder.CreateCall(HostCallee, Args);
  // A void callee produces nothing to send back; converting a void value to
  // i64 is not possible.
  if (!RetVal->getType()->isVoidTy()) {
    RetVal = convertToInt64Ty(Builder, RetVal);
    Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_set_ret_val],
                       {Desc, RetVal});
  }
  Builder.CreateRetVoid();

  return WrapperFn;
}
} // namespace llvm
int main(int argc, char *argv[]) {
InitializeAllTargets();
InitializeAllTargetMCs();
InitializeAllAsmPrinters();
InitializeAllAsmParsers();
SMDiagnostic Err;
std::unique_ptr<Module> DM = parseIRFile("device.ll", Err, Context);
if (!DM)
return 1;
Module HM("host-rpc.bc", Context);
// get the right target triple
HM.setTargetTriple(Triple::normalize("x86-64"));
AutoHostRPC RPC(*DM, HM);
(void)RPC.run();
DM->dump();
// HM.dump();
return 0;
}

19
auto-host-rpc/test.c Normal file
View File

@@ -0,0 +1,19 @@
#include <stdio.h>
#pragma omp begin declare target device_type(nohost)
/* Three-field record; foo() passes the address of field `a` to fscanf. */
struct ddd {
int num;
int a;
float b;
};
/*
 * Device-only test function (inside `declare target device_type(nohost)`)
 * that calls fopen, fprintf (with two different variadic argument lists) and
 * fscanf.
 *
 * NOTE(review): fp is opened with mode "r" but written with fprintf, its
 * result is not checked for NULL, and it is never closed - presumably
 * acceptable for a compile-only fixture; confirm before reusing as runnable
 * code.
 */
void foo() {
FILE *fp = fopen("main.cpp", "r");
struct ddd d;
/* One variadic int argument. */
fprintf(fp, "%d", 6);
/* Variadic float, int and string arguments. */
fprintf(fp, "%f%d%s", 6.0f, 1, "hello");
/* Pointer into a local object is handed to the callee for writing. */
fscanf(fp, "%d", &d.a);
}
#pragma omp end declare target

View File

@@ -57,6 +57,8 @@ using namespace llvm;
using namespace llvm::opt;
using namespace llvm::object;
extern "C" int printf(const char *, ...);
/// Path of the current binary.
static const char *LinkerExecutable;
@@ -86,6 +88,8 @@ static std::atomic<bool> LTOError;
using OffloadingImage = OffloadBinary::OffloadingImage;
Module *HostModule = nullptr;
namespace llvm {
// Provide DenseMapInfo so that OffloadKind can be used in a DenseMap.
template <> struct DenseMapInfo<OffloadKind> {
@@ -774,6 +778,14 @@ Error linkBitcodeFiles(SmallVectorImpl<OffloadFile> &InputFiles,
? createLTO(Args, Features, OutputBitcode)
: createLTO(Args, Features);
LLVMContext &Ctx = LTOBackend->getContext();
std::unique_ptr<Module> HostModulePtr =
std::make_unique<Module>("host-rpc.bc", Ctx);
HostModule = HostModulePtr.get();
HostModulePtr->setTargetTriple(
Args.getLastArgValue(OPT_host_triple_EQ, sys::getDefaultTargetTriple()));
printf("[linker-wrapper] HostModule=%p\n", HostModule);
// We need to resolve the symbols so the LTO backend knows which symbols need
// to be kept or can be internalized. This is a simplified symbol resolution
// scheme to approximate the full resolution a linker would do.
@@ -863,6 +875,10 @@ Error linkBitcodeFiles(SmallVectorImpl<OffloadFile> &InputFiles,
if (Error Err = LTOBackend->run(AddStream))
return Err;
// Reset the HostModule pointer.
HostModulePtr.reset();
HostModule = nullptr;
if (LTOError)
return createStringError(inconvertibleErrorCode(),
"Errors encountered inside the LTO pipeline.");

View File

@@ -281,6 +281,9 @@ public:
/// by LTO but might not be visible from bitcode symbol table.
static ArrayRef<const char*> getRuntimeLibcallSymbols();
/// Returns the context.
LLVMContext &getContext() { return RegularLTO.Ctx; }
private:
Config Conf;

View File

@@ -0,0 +1,24 @@
//===- Transform/Utils/HostRPC.h - Code of automatic host rpc ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
//
//===----------------------------------------------------------------------===//
#ifndef LLVM_TRANSFORMS_UTILS_HOSTRPC_H
#define LLVM_TRANSFORMS_UTILS_HOSTRPC_H
#include "llvm/IR/PassManager.h"
namespace llvm {
/// Module pass for automatic host RPC, written against the new pass manager
/// interface (PassInfoMixin).
class HostRPCPass : public PassInfoMixin<HostRPCPass> {
public:
/// Run the pass on module \p M; returns the set of analyses that remain
/// valid afterwards.
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
};
} // namespace llvm
#endif // LLVM_TRANSFORMS_UTILS_HOSTRPC_H

View File

@@ -228,6 +228,7 @@
#include "llvm/Transforms/Utils/EntryExitInstrumenter.h"
#include "llvm/Transforms/Utils/FixIrreducible.h"
#include "llvm/Transforms/Utils/HelloWorld.h"
#include "llvm/Transforms/Utils/HostRPC.h"
#include "llvm/Transforms/Utils/InjectTLIMappings.h"
#include "llvm/Transforms/Utils/InstructionNamer.h"
#include "llvm/Transforms/Utils/LCSSA.h"

View File

@@ -118,6 +118,7 @@
#include "llvm/Transforms/Utils/AddDiscriminators.h"
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
#include "llvm/Transforms/Utils/CanonicalizeAliases.h"
#include "llvm/Transforms/Utils/HostRPC.h"
#include "llvm/Transforms/Utils/InjectTLIMappings.h"
#include "llvm/Transforms/Utils/LibCallsShrinkWrap.h"
#include "llvm/Transforms/Utils/Mem2Reg.h"
@@ -190,6 +191,9 @@ static cl::opt<bool> EnableGlobalAnalyses(
"enable-global-analyses", cl::init(true), cl::Hidden,
cl::desc("Enable inter-procedural analyses"));
cl::opt<bool> EnableHostRPC("enable-host-rpc", cl::init(false), cl::Hidden,
cl::desc("Enable HostRPC pass"));
PipelineTuningOptions::PipelineTuningOptions() {
LoopInterleaving = true;
LoopVectorization = true;
@@ -918,6 +922,9 @@ PassBuilder::buildModuleSimplificationPipeline(OptimizationLevel Level,
PGOIndirectCallPromotion(true /* IsInLTO */, true /* SamplePGO */));
}
if (EnableHostRPC)
MPM.addPass(HostRPCPass());
// Try to perform OpenMP specific optimizations on the module. This is a
// (quick!) no-op if there are no OpenMP runtime calls present in the module.
if (Level != OptimizationLevel::O0)

View File

@@ -127,6 +127,7 @@ MODULE_PASS("sanmd-module", SanitizerBinaryMetadataPass())
MODULE_PASS("memprof-module", ModuleMemProfilerPass())
MODULE_PASS("poison-checking", PoisonCheckingPass())
MODULE_PASS("pseudo-probe-update", PseudoProbeUpdatePass())
MODULE_PASS("host-rpc", HostRPCPass())
#undef MODULE_PASS
#ifndef MODULE_PASS_WITH_PARAMS

View File

@@ -29,6 +29,7 @@ add_llvm_component_library(LLVMTransformUtils
GlobalStatus.cpp
GuardUtils.cpp
HelloWorld.cpp
HostRPC.cpp
InlineFunction.cpp
InjectTLIMappings.cpp
InstructionNamer.cpp

View File

@@ -0,0 +1,553 @@
//===- Transform/Utils/HostRPC.cpp - Code of automatic host rpc -*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
//
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Utils/HostRPC.h"
#include "llvm/ADT/EnumeratedArray.h"
#include "llvm/ADT/Triple.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/CommandFlags.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/TargetSelect.h"
#include "llvm/Target/TargetOptions.h"
using namespace llvm;
#define DEBUG_TYPE "host-rpc"
extern "C" int printf(const char *, ...);
__attribute__((weak)) Module *HostModule = nullptr;
namespace {
// TODO: Remove those functions implemented in device runtime.
static constexpr const char *InternalPrefix[] = {
"__kmp", "llvm.", "nvm.", "omp_", "vprintf", "malloc", "free"};
/// Map an LLVM type to the short tag that is appended to wrapper names for
/// variadic callees: "i<N>" for integers, "ptr" for pointers, the struct name
/// for struct types, and "f16"/"f32"/"f64" for the handled floating-point
/// types. Any other type is treated as unreachable.
std::string typeToString(Type *T) {
  if (T->isIntegerTy()) {
    std::string Tag = "i";
    Tag += std::to_string(T->getIntegerBitWidth());
    return Tag;
  }
  if (T->isPointerTy())
    return "ptr";
  if (T->isStructTy())
    return T->getStructName().str();

  // Floating-point tags; the predicates are mutually exclusive.
  const char *FPTag = T->is16bitFPTy() ? "f16"
                      : T->isFloatTy() ? "f32"
                      : T->isDoubleTy() ? "f64"
                                        : nullptr;
  if (FPTag)
    return FPTag;

  llvm_unreachable("unknown type");
}
/// Identifiers of the host RPC runtime entry points. The enumerators are named
/// OMPRTL_<function name>; OMPRTL___last is the final enumerator and serves as
/// the upper bound of the array indexed by this enum.
enum class HostRPCRuntimeFunction {
#define __OMPRTL_HOST_RPC(_ENUM) OMPRTL_##_ENUM
  __OMPRTL_HOST_RPC(__kmpc_host_rpc_get_desc),
  __OMPRTL_HOST_RPC(__kmpc_host_rpc_add_arg),
  __OMPRTL_HOST_RPC(__kmpc_host_rpc_get_arg),
  __OMPRTL_HOST_RPC(__kmpc_host_rpc_send_and_wait),
  __OMPRTL_HOST_RPC(__kmpc_host_rpc_get_ret_val),
  __OMPRTL_HOST_RPC(__kmpc_host_rpc_set_ret_val),
  __OMPRTL_HOST_RPC(__last),
#undef __OMPRTL_HOST_RPC
};

// Unqualified aliases for the enumerators above. They are compile-time
// constants rather than mutable globals, so they cannot be reassigned by
// accident.
#define __OMPRTL_HOST_RPC(_ENUM)                                               \
  constexpr auto OMPRTL_##_ENUM = HostRPCRuntimeFunction::OMPRTL_##_ENUM;
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_desc)
__OMPRTL_HOST_RPC(__kmpc_host_rpc_add_arg)
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_arg)
__OMPRTL_HOST_RPC(__kmpc_host_rpc_send_and_wait)
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_ret_val)
__OMPRTL_HOST_RPC(__kmpc_host_rpc_set_ret_val)
#undef __OMPRTL_HOST_RPC
// Classification of a single RPC argument. The numeric value is stored as the
// "type" field of the per-argument info entry built at each rewritten call
// site.
enum OMPHostRPCArgType {
// No need to copy.
OMP_HOST_RPC_ARG_SCALAR = 0,
// Pointer argument. NOTE(review): no copy direction is attached here -
// confirm the intended semantics.
OMP_HOST_RPC_ARG_PTR = 1,
// Copy to host.
OMP_HOST_RPC_ARG_PTR_COPY_TO = 2,
// Copy to device
OMP_HOST_RPC_ARG_PTR_COPY_FROM = 3,
// Default for pointer arguments whose underlying object is not a constant
// global.
// TODO: Do we have a tofrom pointer?
OMP_HOST_RPC_ARG_PTR_COPY_TOFROM = 4,
};
/// Implementation of the host RPC rewrite for one device module. Host-side
/// wrappers and runtime declarations are emitted into the module pointed to by
/// the global `HostModule`, which must be set before an instance is created.
class HostRPC {
  LLVMContext &Context;
  // Device module
  Module &M;

  // Types
  Type *Int8PtrTy;
  Type *VoidTy;
  Type *Int32Ty;
  Type *Int64Ty;
  // { i64 type, i64 size } entry describing one call argument.
  StructType *ArgInfoTy;

  /// A call being rewritten together with the types of its actual arguments.
  struct CallSiteInfo {
    CallInst *CI = nullptr;
    SmallVector<Type *> Params;
  };

  /// Per-argument description stored into an ArgInfoTy entry.
  struct HostRPCArgInfo {
    // OMPHostRPCArgType
    Constant *Type;
    Value *Size;
  };

  // Host wrappers in the order they were first used; the index of a wrapper in
  // this table is the call number passed to its device wrapper.
  SmallVector<Function *> HostEntryTable;

  // Runtime function declarations, indexed by HostRPCRuntimeFunction.
  EnumeratedArray<Function *, HostRPCRuntimeFunction,
                  HostRPCRuntimeFunction::OMPRTL___last>
      RFIs;

  /// Build the wrapper name for a call to \p F. For variadic callees the types
  /// of the variadic arguments at this call site are appended, so call sites
  /// with different variadic argument lists get different wrappers.
  static std::string getWrapperFunctionName(Function *F, CallSiteInfo &CSI) {
    std::string Name = "__kmpc_host_rpc_wrapper_" + std::string(F->getName());
    if (!F->isVarArg())
      return Name;
    for (unsigned I = F->getFunctionType()->getNumParams();
         I < CSI.Params.size(); ++I) {
      Name.push_back('_');
      Name.append(typeToString(CSI.Params[I]));
    }
    return Name;
  }

  /// Return true if the name of \p F starts with one of InternalPrefix; such
  /// functions are never rewritten.
  static bool isInternalFunction(Function &F) {
    auto Name = F.getName();
    for (auto P : InternalPrefix)
      if (Name.startswith(P))
        return true;
    return false;
  }

  Value *convertToInt64Ty(IRBuilder<> &Builder, Value *V);

  Value *convertFromInt64TyTo(IRBuilder<> &Builder, Value *V, Type *TargetTy);

  // int device_wrapper(call_no, arg_info, ...) {
  //   void *desc = __kmpc_host_rpc_get_desc(call_no, num_args, arg_info);
  //   __kmpc_host_rpc_add_arg(desc, arg1, sizeof(arg1));
  //   __kmpc_host_rpc_add_arg(desc, arg2, sizeof(arg2));
  //   ...
  //   __kmpc_host_rpc_send_and_wait(desc);
  //   int r = (int)__kmpc_host_rpc_get_ret_val(desc);
  //   return r;
  // }
  Function *getDeviceWrapperFunction(StringRef WrapperName, Function *F,
                                     CallSiteInfo &CSI);

  // void host_wrapper(desc) {
  //   int arg1 = (int)__kmpc_host_rpc_get_arg(desc, 0);
  //   float arg2 = (float)__kmpc_host_rpc_get_arg(desc, 1);
  //   char *arg3 = (char *)__kmpc_host_rpc_get_arg(desc, 2);
  //   ...
  //   int r = actual_call(arg1, arg2, arg3, ...);
  //   __kmpc_host_rpc_set_ret_val(desc, (int64_t)r);
  // }
  Function *getHostWrapperFunction(StringRef WrapperName, Function *F,
                                   CallSiteInfo &CSI);

  bool rewriteWithHostRPC(Function *F);

public:
  HostRPC(Module &M) : Context(M.getContext()), M(M) {
    // HostModule is a weak global that defaults to nullptr. Fail with a clear
    // message (also in builds without assertions) instead of dereferencing a
    // null pointer below.
    if (!HostModule)
      report_fatal_error("host module is not set for the HostRPC pass");

    assert(&M.getContext() == &HostModule->getContext() &&
           "device and host modules have different context");

#define __OMP_TYPE(TYPE) TYPE = Type::get##TYPE(Context)
    __OMP_TYPE(Int8PtrTy);
    __OMP_TYPE(VoidTy);
    __OMP_TYPE(Int32Ty);
    __OMP_TYPE(Int64Ty);
#undef __OMP_TYPE

// Declare a runtime function in MOD unless the module already has one with
// that name.
#define __OMP_RTL(_ENUM, MOD, VARARG, RETTY, ...)                              \
  {                                                                            \
    SmallVector<Type *> Params{__VA_ARGS__};                                   \
    FunctionType *FT = FunctionType::get(RETTY, Params, VARARG);               \
    Function *F = (MOD).getFunction(#_ENUM);                                   \
    if (!F)                                                                    \
      F = Function::Create(FT, GlobalValue::LinkageTypes::ExternalLinkage,     \
                           #_ENUM, (MOD));                                     \
    RFIs[OMPRTL_##_ENUM] = F;                                                  \
  }
    // Device-side runtime entry points.
    __OMP_RTL(__kmpc_host_rpc_get_desc, M, false, Int8PtrTy, Int32Ty, Int32Ty,
              Int8PtrTy)
    __OMP_RTL(__kmpc_host_rpc_add_arg, M, false, VoidTy, Int8PtrTy, Int64Ty,
              Int32Ty)
    __OMP_RTL(__kmpc_host_rpc_send_and_wait, M, false, VoidTy, Int8PtrTy)
    __OMP_RTL(__kmpc_host_rpc_get_ret_val, M, false, Int64Ty, Int8PtrTy)
    // Host-side runtime entry points.
    __OMP_RTL(__kmpc_host_rpc_get_arg, *HostModule, false, Int64Ty, Int8PtrTy,
              Int32Ty)
    __OMP_RTL(__kmpc_host_rpc_set_ret_val, *HostModule, false, VoidTy,
              Int8PtrTy, Int64Ty)
#undef __OMP_RTL

    ArgInfoTy = StructType::create({Int64Ty, Int64Ty}, "struct.arg_info_t");
  }

  /// Rewrite every eligible call in the device module. Returns true if the
  /// module was changed.
  bool run();
}; // class HostRPC
/// Pack \p V into an i64 so it can travel through the RPC argument/return
/// slots.
///  - i64 is returned unchanged,
///  - pointers are converted with ptrtoint,
///  - other integers are zero-extended (or truncated) to i64,
///  - floating-point values of at most 64 bits are transported as their raw
///    bit pattern (bitcast to the same-width integer, then zero-extended), so
///    convertFromInt64TyTo recovers the exact value. The previous
///    fptosi-based code lost fractional parts and, for float, applied fptosi
///    to an integer operand, which is not a valid cast.
Value *HostRPC::convertToInt64Ty(IRBuilder<> &Builder, Value *V) {
  Type *T = V->getType();
  if (T == Int64Ty)
    return V;
  if (T->isPointerTy())
    return Builder.CreatePtrToInt(V, Int64Ty);
  if (T->isIntegerTy())
    return Builder.CreateIntCast(V, Int64Ty, /* isSigned */ false);
  if (T->isFloatingPointTy()) {
    unsigned Bits = T->getScalarSizeInBits();
    if (Bits > 64)
      llvm_unreachable("floating-point type does not fit into int64_t");
    Value *Raw = Builder.CreateBitCast(V, Builder.getIntNTy(Bits));
    return Builder.CreateZExtOrTrunc(Raw, Int64Ty);
  }
  llvm_unreachable("unknown cast to int64_t");
}

/// Inverse of convertToInt64Ty: unpack the i64 \p V into a value of type \p T.
/// Floating-point targets are rebuilt from the low bits of \p V by truncating
/// to the same-width integer and bitcasting.
Value *HostRPC::convertFromInt64TyTo(IRBuilder<> &Builder, Value *V, Type *T) {
  if (T == Int64Ty)
    return V;
  if (T->isPointerTy())
    return Builder.CreateIntToPtr(V, T);
  if (T->isIntegerTy())
    return Builder.CreateIntCast(V, T, /* isSigned */ true);
  if (T->isFloatingPointTy()) {
    unsigned Bits = T->getScalarSizeInBits();
    if (Bits > 64)
      llvm_unreachable("floating-point type does not fit into int64_t");
    Value *Raw = Builder.CreateZExtOrTrunc(V, Builder.getIntNTy(Bits));
    return Builder.CreateBitCast(Raw, T);
  }
  llvm_unreachable("unknown cast from int64_t");
}
bool HostRPC::run() {
bool Changed = false;
SmallVector<Function *> WorkList;
for (Function &F : M) {
// If the function is already defined, it definitely does not require RPC.
if (!F.isDeclaration())
continue;
// If it is an internal function, skip it as well.
if (isInternalFunction(F))
continue;
// If there is no use of the function, skip it.
if (F.use_empty())
continue;
WorkList.push_back(&F);
}
if (WorkList.empty())
return Changed;
for (Function *F : WorkList)
Changed |= rewriteWithHostRPC(F);
return Changed;
}
/// Replace every direct call to the declaration \p F with a call to a
/// generated device wrapper that forwards the call to the host.
///
/// For each call site this
///   1. creates (or reuses) the device wrapper and the host wrapper,
///   2. registers the host wrapper in HostEntryTable and uses its index as the
///      call number,
///   3. builds an "arg_info" array with one {type, size} entry per argument,
///      either as a constant global or, if any size is only known at run
///      time, as a stack array filled in right before the call,
///   4. emits `wrapper(call_no, arg_info, original args...)` and erases the
///      original call.
///
/// \p F itself is erased only once no use of it is left.
///
/// \returns true if at least one call site was rewritten.
bool HostRPC::rewriteWithHostRPC(Function *F) {
  // Collect the call sites first because the rewrite erases them. Only uses
  // where F is the callee of a direct call are taken: a call that merely
  // passes F as an argument must not be rewritten, and going through the
  // callee use guarantees every call is queued exactly once.
  SmallVector<CallInst *> WorkList;
  for (Use &U : F->uses()) {
    auto *CI = dyn_cast<CallInst>(U.getUser());
    if (!CI || !CI->isCallee(&U) || CI->getCalledFunction() != F)
      continue;
    WorkList.push_back(CI);
  }

  if (WorkList.empty())
    return false;

  const DataLayout &DL = M.getDataLayout();

  // True if V is the result of a direct call to a generated RPC wrapper. Such
  // pointers are passed through as plain scalars.
  auto CheckIfIdentifierPtr = [](const Value *V) {
    const auto *Call = dyn_cast<CallInst>(V);
    if (!Call)
      return false;
    // getCalledFunction() is null for indirect calls.
    const Function *Callee = Call->getCalledFunction();
    return Callee &&
           Callee->getName().startswith("__kmpc_host_rpc_wrapper_");
  };

  // True if V is the result of a direct call to a known allocation function
  // whose first argument is the allocation size.
  auto CheckIfAlloca = [](const Value *V) {
    const auto *Call = dyn_cast<CallInst>(V);
    if (!Call)
      return false;
    const Function *Callee = Call->getCalledFunction();
    return Callee && (Callee->getName() == "__kmpc_alloc_shared" ||
                      Callee->getName() == "malloc");
  };

  for (CallInst *CI : WorkList) {
    CallSiteInfo CSI;
    CSI.CI = CI;
    unsigned NumArgs = CI->arg_size();
    for (unsigned I = 0; I < NumArgs; ++I)
      CSI.Params.push_back(CI->getArgOperand(I)->getType());

    std::string WrapperName = getWrapperFunctionName(F, CSI);
    Function *DeviceWrapperFn = getDeviceWrapperFunction(WrapperName, F, CSI);
    Function *HostWrapperFn = getHostWrapperFunction(WrapperName, F, CSI);

    // The call number is the index of the host wrapper in the entry table;
    // append the wrapper if it is not registered yet.
    int32_t WrapperNumber = -1;
    for (unsigned I = 0; I < HostEntryTable.size(); ++I) {
      if (HostEntryTable[I] == HostWrapperFn) {
        WrapperNumber = I;
        break;
      }
    }
    if (WrapperNumber == -1) {
      WrapperNumber = HostEntryTable.size();
      HostEntryTable.push_back(HostWrapperFn);
    }

    IRBuilder<> Builder(CI);

    // Classify every argument as a scalar or as a pointer with a known size.
    SmallVector<HostRPCArgInfo> ArgInfos;
    bool IsConstantArgInfo = true;
    for (Value *Op : CI->args()) {
      if (!Op->getType()->isPointerTy()) {
        HostRPCArgInfo AI{
            ConstantInt::get(Int64Ty,
                             OMPHostRPCArgType::OMP_HOST_RPC_ARG_SCALAR),
            ConstantInt::getNullValue(Int64Ty)};
        ArgInfos.push_back(std::move(AI));
        continue;
      }

      Value *SizeVal = nullptr;
      OMPHostRPCArgType ArgType = OMP_HOST_RPC_ARG_PTR_COPY_TOFROM;

      SmallVector<const Value *> Objects;
      getUnderlyingObjects(Op, Objects);
      // TODO: Handle phi node
      if (Objects.size() != 1)
        llvm_unreachable("we can't handle phi node yet");
      const Value *Obj = Objects.front();

      if (CheckIfIdentifierPtr(Obj)) {
        ArgType = OMP_HOST_RPC_ARG_SCALAR;
        SizeVal = ConstantInt::getNullValue(Int64Ty);
      } else if (CheckIfAlloca(Obj)) {
        // The allocation size is the first argument of the allocation call.
        SizeVal = cast<CallInst>(Obj)->getArgOperand(0);
        if (!isa<Constant>(SizeVal))
          IsConstantArgInfo = false;
      } else if (auto *GV = dyn_cast<GlobalVariable>(Obj)) {
        SizeVal = ConstantInt::get(Int64Ty,
                                   DL.getTypeStoreSize(GV->getValueType()));
        // Constant globals only need to be copied to the host.
        // TODO: If the global variable is constant and has an initializer, we
        // can do some optimization.
        if (GV->isConstant())
          ArgType = OMP_HOST_RPC_ARG_PTR_COPY_TO;
      } else {
        // TODO: fix that when it occurs
        llvm_unreachable("cannot handle unknown type");
      }

      HostRPCArgInfo AI{ConstantInt::get(Int64Ty, ArgType), SizeVal};
      ArgInfos.push_back(std::move(AI));
    }

    // Materialize the arg_info array.
    ArrayType *ArgInfoArrTy = ArrayType::get(ArgInfoTy, NumArgs);
    Value *ArgInfo = nullptr;
    if (!IsConstantArgInfo) {
      // At least one size is a run-time value: build the array on the stack
      // and fill it in right before the call.
      ArgInfo = Builder.CreateAlloca(
          ArgInfoTy, ConstantInt::get(Int64Ty, NumArgs), "arg_info");
      for (unsigned I = 0; I < NumArgs; ++I) {
        // Address of element I of the array.
        Value *AII =
            Builder.CreateConstInBoundsGEP2_64(ArgInfoArrTy, ArgInfo, 0, I);
        // Addresses of the two fields of that element. A struct GEP is
        // required here: indexing the struct type with a single index would
        // step over whole elements instead of selecting a field.
        Value *AIIType = Builder.CreateStructGEP(ArgInfoTy, AII, 0);
        Value *AIISize = Builder.CreateStructGEP(ArgInfoTy, AII, 1);
        Builder.CreateStore(ArgInfos[I].Type, AIIType);
        Builder.CreateStore(ArgInfos[I].Size, AIISize);
      }
    } else {
      // Everything is known at compile time: emit a constant global instead.
      SmallVector<Constant *> ArgInfoInitVar;
      for (auto &AI : ArgInfos) {
        auto *CS =
            ConstantStruct::get(ArgInfoTy, {AI.Type, cast<Constant>(AI.Size)});
        ArgInfoInitVar.push_back(CS);
      }
      Constant *ArgInfoInit = ConstantArray::get(ArgInfoArrTy, ArgInfoInitVar);
      ArgInfo = new GlobalVariable(
          M, ArgInfoArrTy, /* isConstant */ true,
          GlobalValue::LinkageTypes::InternalLinkage, ArgInfoInit, "arg_info");
    }

    // wrapper(call_no, arg_info, original arguments...)
    SmallVector<Value *> Args{ConstantInt::get(Int32Ty, WrapperNumber),
                              ArgInfo};
    for (Value *Op : CI->args())
      Args.push_back(Op);

    CallInst *NewCall = Builder.CreateCall(DeviceWrapperFn, Args);
    CI->replaceAllUsesWith(NewCall);
    CI->eraseFromParent();
  }

  // Uses that are not direct calls (e.g. the address of F being taken) keep
  // the declaration alive; erasing it while it still has uses is invalid.
  if (F->use_empty())
    F->eraseFromParent();

  return true;
}
/// Return the device-side wrapper named \p WrapperName, creating it in the
/// device module M on first request.
///
/// The wrapper has the shape
///   return_type device_wrapper(int32_t call_no, void *arg_info, params...)
/// where `params...` are the argument types recorded in \p CSI and
/// `return_type` is the return type of \p F. Its body requests a descriptor
/// via OMPRTL___kmpc_host_rpc_get_desc, registers every forwarded argument
/// (converted to i64) via OMPRTL___kmpc_host_rpc_add_arg, issues
/// OMPRTL___kmpc_host_rpc_send_and_wait and, for non-void functions, returns
/// the value obtained from OMPRTL___kmpc_host_rpc_get_ret_val converted back
/// to the return type.
Function *HostRPC::getDeviceWrapperFunction(StringRef WrapperName, Function *F,
                                            CallSiteInfo &CSI) {
  Function *WrapperFn = M.getFunction(WrapperName);
  if (WrapperFn)
    return WrapperFn;

  // return_type device_wrapper(int32_t call_no, void *arg_info, ...)
  SmallVector<Type *> Params{Int32Ty, Int8PtrTy};
  Params.append(CSI.Params);

  Type *RetTy = F->getReturnType();
  FunctionType *FT = FunctionType::get(RetTy, Params, /*isVarArg*/ false);
  WrapperFn = Function::Create(FT, GlobalValue::LinkageTypes::WeakODRLinkage,
                               WrapperName, M);

  // Emit the body of the device wrapper
  IRBuilder<> Builder(Context);
  BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", WrapperFn);
  Builder.SetInsertPoint(EntryBB);

  // call_no and arg_info are not forwarded to the host callee.
  constexpr const unsigned NumArgSkipped = 2;
  const unsigned NumForwardedArgs = WrapperFn->arg_size() - NumArgSkipped;

  Value *Desc = Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_get_desc],
                                   {WrapperFn->getArg(0),
                                    ConstantInt::get(Int32Ty, NumForwardedArgs),
                                    WrapperFn->getArg(1)},
                                   "desc");

  // Register the forwarded arguments. The argument number is relative to the
  // original call (0-based), matching the layout of arg_info and the indices
  // the host wrapper uses to read the arguments back; it is NOT the position
  // of the parameter in this wrapper.
  Function *AddArgFn = RFIs[OMPRTL___kmpc_host_rpc_add_arg];
  for (unsigned I = 0; I < NumForwardedArgs; ++I) {
    Value *V =
        convertToInt64Ty(Builder, WrapperFn->getArg(I + NumArgSkipped));
    Builder.CreateCall(AddArgFn, {Desc, V, ConstantInt::get(Int32Ty, I)});
  }

  Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_send_and_wait], {Desc});

  if (RetTy->isVoidTy()) {
    Builder.CreateRetVoid();
    return WrapperFn;
  }

  // The runtime hands the result back as i64; convert it if needed.
  Value *RetVal =
      Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_get_ret_val], {Desc});
  if (RetTy != RetVal->getType())
    RetVal = convertFromInt64TyTo(Builder, RetVal, RetTy);

  Builder.CreateRet(RetVal);

  return WrapperFn;
}
/// Return the host-side wrapper named \p WrapperName, creating it in
/// HostModule on first request.
///
/// The wrapper has the shape `void host_wrapper(void *desc)`. Its body reads
/// each argument of the original call from the descriptor via
/// OMPRTL___kmpc_host_rpc_get_arg, converts it from i64 to the type recorded
/// in \p CSI, calls the host function with the same name as \p F (declaring
/// it in HostModule if necessary) and, unless that function returns void,
/// stores the result converted to i64 via OMPRTL___kmpc_host_rpc_set_ret_val.
Function *HostRPC::getHostWrapperFunction(StringRef WrapperName, Function *F,
                                          CallSiteInfo &CSI) {
  Function *WrapperFn = HostModule->getFunction(WrapperName);
  if (WrapperFn)
    return WrapperFn;

  SmallVector<Type *> Params{Int8PtrTy};
  FunctionType *FT = FunctionType::get(VoidTy, Params, /* isVarArg */ false);
  WrapperFn = Function::Create(FT, GlobalValue::LinkageTypes::ExternalLinkage,
                               WrapperName, *HostModule);

  Value *Desc = WrapperFn->getArg(0);

  // Emit the body of the host wrapper
  IRBuilder<> Builder(Context);
  BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", WrapperFn);
  Builder.SetInsertPoint(EntryBB);

  // Unpack the arguments of the original call from the descriptor.
  SmallVector<Value *> Args;
  for (unsigned I = 0; I < CSI.CI->arg_size(); ++I) {
    Value *V = Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_get_arg],
                                  {Desc, ConstantInt::get(Int32Ty, I)});
    Args.push_back(convertFromInt64TyTo(Builder, V, CSI.Params[I]));
  }

  // The host callee that will be called eventually by the host wrapper.
  Function *HostCallee = HostModule->getFunction(F->getName());
  if (!HostCallee)
    HostCallee = Function::Create(F->getFunctionType(), F->getLinkage(),
                                  F->getName(), *HostModule);

  CallInst *HostCall = Builder.CreateCall(HostCallee, Args);

  // A void callee produces no value that could be converted to i64, so there
  // is nothing to report back; the device wrapper does not query a return
  // value for void functions either.
  if (!HostCall->getType()->isVoidTy()) {
    Value *RetVal = convertToInt64Ty(Builder, HostCall);
    Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_set_ret_val],
                       {Desc, RetVal});
  }

  Builder.CreateRetVoid();

  return WrapperFn;
}
} // namespace
/// New-pass-manager entry point: run the host RPC rewrite on \p M.
///
/// The pass is a no-op when no host module is available. \p AM is unused.
///
/// \returns PreservedAnalyses::none() if the module was modified, otherwise
/// PreservedAnalyses::all().
PreservedAnalyses HostRPCPass::run(Module &M, ModuleAnalysisManager &AM) {
  // NOTE(review): debug output on stdout; consider routing it through the
  // LLVM debug stream before this lands.
  printf("[HostRPCPass] HostModule=%p\n", HostModule);
  if (!HostModule)
    return PreservedAnalyses::all();

  HostRPC RPC(M);
  const bool Changed = RPC.run();

  // A modified module invalidates the analyses; an untouched one keeps them.
  return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}

View File

@@ -100,6 +100,7 @@ set(include_files
set(src_files
${source_directory}/Configuration.cpp
${source_directory}/Debug.cpp
${source_directory}/HostRPC.cpp
${source_directory}/Kernel.cpp
${source_directory}/LibC.cpp
${source_directory}/Mapping.cpp

View File

@@ -351,6 +351,15 @@ int32_t __kmpc_cancel(IdentTy *Loc, int32_t TId, int32_t CancelVal);
int32_t __kmpc_shuffle_int32(int32_t val, int16_t delta, int16_t size);
int64_t __kmpc_shuffle_int64(int64_t val, int16_t delta, int16_t size);
///}
/// Host RPC
///
/// Device-side entry points for forwarding a call to the host. The
/// descriptor is an opaque pointer that is passed back to the other entry
/// points.
///
/// {

/// Obtain the descriptor for call number \p CallNo, which forwards
/// \p NumArgs arguments described by \p ArgInfo.
void *__kmpc_host_rpc_get_desc(int32_t CallNo, int32_t NumArgs, void *ArgInfo);

/// Attach argument number \p ArgNum, encoded as a 64-bit integer \p Arg, to
/// the descriptor \p Desc.
void __kmpc_host_rpc_add_arg(void *Desc, int64_t Arg, int32_t ArgNum);

/// Issue the call described by \p Desc.
/// NOTE(review): presumably blocks until the host has completed the call, as
/// the name suggests - confirm once the runtime side is implemented.
void __kmpc_host_rpc_send_and_wait(void *Desc);

/// Fetch the return value associated with \p Desc as a 64-bit integer.
int64_t __kmpc_host_rpc_get_ret_val(void *Desc);
/// }
}
#endif

View File

@@ -0,0 +1,25 @@
//===------- HostRPC.cpp - Implementation of host RPC ------------- C++ ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "Types.h"
#pragma omp begin declare target device_type(nohost)
extern "C" {

/// Placeholder implementation: no descriptor is created yet, so the result is
/// always a null pointer regardless of the inputs.
void *__kmpc_host_rpc_get_desc(int32_t CallNo, int32_t NumArgs, void *ArgInfo) {
  (void)CallNo;
  (void)NumArgs;
  (void)ArgInfo;
  void *const NoDesc = nullptr;
  return NoDesc;
}

/// Placeholder implementation: the argument is ignored.
void __kmpc_host_rpc_add_arg(void *Desc, int64_t Arg, int32_t ArgNum) {
  (void)Desc;
  (void)Arg;
  (void)ArgNum;
}

/// Placeholder implementation: nothing is sent and nothing is waited for.
void __kmpc_host_rpc_send_and_wait(void *Desc) { (void)Desc; }

/// Placeholder implementation: the return value is always zero.
int64_t __kmpc_host_rpc_get_ret_val(void *Desc) {
  (void)Desc;
  return int64_t{0};
}

} // extern "C"
#pragma omp end declare target