Compare commits
1 Commits
llvm-test-
...
auto-host-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
676d2638f3 |
65
auto-host-rpc/device.ll
Normal file
65
auto-host-rpc/device.ll
Normal file
@@ -0,0 +1,65 @@
|
||||
; ModuleID = 'test-openmp-nvptx64-nvidia-cuda-sm_75.bc'
|
||||
source_filename = "/home/shiltian/Documents/vscode/llvm-project/auto-host-rpc/test.c"
|
||||
target datalayout = "e-i64:64-i128:128-v16:16-v32:32-n16:32:64"
|
||||
target triple = "nvptx64-nvidia-cuda"
|
||||
|
||||
%struct.ddd = type { i32, i32, float }
|
||||
|
||||
@__omp_rtl_debug_kind = weak_odr hidden local_unnamed_addr constant i32 0
|
||||
@__omp_rtl_assume_teams_oversubscription = weak_odr hidden local_unnamed_addr constant i32 0
|
||||
@__omp_rtl_assume_threads_oversubscription = weak_odr hidden local_unnamed_addr constant i32 0
|
||||
@__omp_rtl_assume_no_thread_state = weak_odr hidden local_unnamed_addr constant i32 0
|
||||
@__omp_rtl_assume_no_nested_parallelism = weak_odr hidden local_unnamed_addr constant i32 0
|
||||
@.str = private unnamed_addr constant [9 x i8] c"main.cpp\00", align 1
|
||||
@.str1 = private unnamed_addr constant [2 x i8] c"r\00", align 1
|
||||
@.str2 = private unnamed_addr constant [3 x i8] c"%d\00", align 1
|
||||
@.str3 = private unnamed_addr constant [7 x i8] c"%f%d%s\00", align 1
|
||||
@.str4 = private unnamed_addr constant [6 x i8] c"hello\00", align 1
|
||||
|
||||
; Function Attrs: convergent nounwind
|
||||
define hidden void @foo() local_unnamed_addr #0 {
|
||||
entry:
|
||||
%d = tail call align 16 dereferenceable_or_null(12) ptr @__kmpc_alloc_shared(i64 12) #4
|
||||
%call = tail call noalias ptr @fopen(ptr noundef nonnull @.str, ptr noundef nonnull @.str1) #5
|
||||
%call1 = tail call i32 (ptr, ptr, ...) @fprintf(ptr noundef %call, ptr noundef nonnull @.str2, i32 noundef 6) #5
|
||||
%call2 = tail call i32 (ptr, ptr, ...) @fprintf(ptr noundef %call, ptr noundef nonnull @.str3, double noundef 6.000000e+00, i32 noundef 1, ptr noundef nonnull @.str4) #5
|
||||
%a = getelementptr inbounds %struct.ddd, ptr %d, i64 0, i32 1
|
||||
%call3 = tail call i32 (ptr, ptr, ...) @fscanf(ptr noundef %call, ptr noundef nonnull @.str2, ptr noundef nonnull %a) #5
|
||||
tail call void @__kmpc_free_shared(ptr %d, i64 12)
|
||||
ret void
|
||||
}
|
||||
|
||||
; Function Attrs: nofree nosync nounwind allocsize(0)
|
||||
declare ptr @__kmpc_alloc_shared(i64) local_unnamed_addr #1
|
||||
|
||||
; Function Attrs: convergent
|
||||
declare noalias ptr @fopen(ptr noundef, ptr noundef) local_unnamed_addr #2
|
||||
|
||||
; Function Attrs: convergent
|
||||
declare i32 @fprintf(ptr noundef, ptr noundef, ...) local_unnamed_addr #2
|
||||
|
||||
; Function Attrs: convergent
|
||||
declare i32 @fscanf(ptr noundef, ptr noundef, ...) local_unnamed_addr #2
|
||||
|
||||
; Function Attrs: nosync nounwind
|
||||
declare void @__kmpc_free_shared(ptr allocptr nocapture, i64) local_unnamed_addr #3
|
||||
|
||||
attributes #0 = { convergent nounwind "frame-pointer"="all" "min-legal-vector-width"="0" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="sm_75" "target-features"="+ptx77,+sm_75" }
|
||||
attributes #1 = { nofree nosync nounwind allocsize(0) }
|
||||
attributes #2 = { convergent "frame-pointer"="all" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="sm_75" "target-features"="+ptx77,+sm_75" }
|
||||
attributes #3 = { nosync nounwind }
|
||||
attributes #4 = { nounwind }
|
||||
attributes #5 = { convergent nounwind }
|
||||
|
||||
!llvm.module.flags = !{!0, !1, !2, !3, !4, !5}
|
||||
!llvm.ident = !{!6, !7}
|
||||
!nvvm.annotations = !{}
|
||||
|
||||
!0 = !{i32 2, !"SDK Version", [2 x i32] [i32 11, i32 7]}
|
||||
!1 = !{i32 1, !"wchar_size", i32 4}
|
||||
!2 = !{i32 7, !"openmp", i32 50}
|
||||
!3 = !{i32 7, !"openmp-device", i32 50}
|
||||
!4 = !{i32 8, !"PIC Level", i32 2}
|
||||
!5 = !{i32 7, !"frame-pointer", i32 2}
|
||||
!6 = !{!"clang version 16.0.0"}
|
||||
!7 = !{!"clang version 3.8.0 (tags/RELEASE_380/final)"}
|
||||
560
auto-host-rpc/main.cpp
Normal file
560
auto-host-rpc/main.cpp
Normal file
@@ -0,0 +1,560 @@
|
||||
#include "llvm/ADT/EnumeratedArray.h"
|
||||
#include "llvm/ADT/Triple.h"
|
||||
#include "llvm/Analysis/ValueTracking.h"
|
||||
#include "llvm/CodeGen/CommandFlags.h"
|
||||
#include "llvm/IR/BasicBlock.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Instructions.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IRReader/IRReader.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
#include "llvm/Target/TargetOptions.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
namespace {
|
||||
static LLVMContext Context;
|
||||
static codegen::RegisterCodeGenFlags RCGF;
|
||||
static constexpr const char *InternalPrefix[] = {"__kmp", "llvm.", "nvm.",
|
||||
"omp_"};
|
||||
|
||||
std::string typeToString(Type *T) {
|
||||
if (T->is16bitFPTy())
|
||||
return "f16";
|
||||
if (T->isFloatTy())
|
||||
return "f32";
|
||||
if (T->isDoubleTy())
|
||||
return "f64";
|
||||
if (T->isPointerTy())
|
||||
return "ptr";
|
||||
if (T->isStructTy())
|
||||
return std::string(T->getStructName());
|
||||
if (T->isIntegerTy())
|
||||
return "i" + std::to_string(T->getIntegerBitWidth());
|
||||
|
||||
llvm_unreachable("unknown type");
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace llvm {
|
||||
|
||||
enum class HostRPCRuntimeFunction {
|
||||
#define __OMPRTL_HOST_RPC(_ENUM) OMPRTL_##_ENUM
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_desc),
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_add_arg),
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_arg),
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_send_and_wait),
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_ret_val),
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_set_ret_val),
|
||||
__OMPRTL_HOST_RPC(__last),
|
||||
#undef __OMPRTL_HOST_RPC
|
||||
};
|
||||
|
||||
#define __OMPRTL_HOST_RPC(_ENUM) \
|
||||
auto OMPRTL_##_ENUM = HostRPCRuntimeFunction::OMPRTL_##_ENUM;
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_desc)
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_add_arg)
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_arg)
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_send_and_wait)
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_ret_val)
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_set_ret_val)
|
||||
#undef __OMPRTL_HOST_RPC
|
||||
|
||||
enum OMPHostRPCArgType {
|
||||
// No need to copy.
|
||||
OMP_HOST_RPC_ARG_SCALAR = 0,
|
||||
OMP_HOST_RPC_ARG_PTR = 1,
|
||||
// Copy to host.
|
||||
OMP_HOST_RPC_ARG_PTR_COPY_TO = 2,
|
||||
// Copy to device
|
||||
OMP_HOST_RPC_ARG_PTR_COPY_FROM = 3,
|
||||
// TODO: Do we have a tofrom pointer?
|
||||
OMP_HOST_RPC_ARG_PTR_COPY_TOFROM = 4,
|
||||
};
|
||||
|
||||
// struct HostRPCArgInfo {
|
||||
// // OMPHostRPCArgType
|
||||
// int64_t Type;
|
||||
// int64_t Size;
|
||||
// };
|
||||
|
||||
class AutoHostRPC {
|
||||
LLVMContext &Context;
|
||||
// Device module
|
||||
Module &DM;
|
||||
// Host module
|
||||
Module &HM;
|
||||
// Types
|
||||
Type *Int8PtrTy;
|
||||
Type *VoidTy;
|
||||
Type *Int32Ty;
|
||||
Type *Int64Ty;
|
||||
StructType *ArgInfoTy;
|
||||
|
||||
struct CallSiteInfo {
|
||||
CallInst *CI = nullptr;
|
||||
SmallVector<Type *> Params;
|
||||
};
|
||||
|
||||
struct HostRPCArgInfo {
|
||||
// OMPHostRPCArgType
|
||||
Constant *Type;
|
||||
Value *Size;
|
||||
};
|
||||
|
||||
//
|
||||
SmallVector<Function *> HostEntryTable;
|
||||
|
||||
EnumeratedArray<Function *, HostRPCRuntimeFunction,
|
||||
HostRPCRuntimeFunction::OMPRTL___last>
|
||||
RFIs;
|
||||
|
||||
static std::string getWrapperFunctionName(Function *F, CallSiteInfo &CSI) {
|
||||
std::string Name = "__kmpc_host_rpc_wrapper_" + std::string(F->getName());
|
||||
if (!F->isVarArg())
|
||||
return Name;
|
||||
|
||||
for (unsigned I = F->getFunctionType()->getNumParams();
|
||||
I < CSI.Params.size(); ++I) {
|
||||
Name.push_back('_');
|
||||
Name.append(typeToString(CSI.Params[I]));
|
||||
}
|
||||
|
||||
return Name;
|
||||
}
|
||||
|
||||
static bool isInternalFunction(Function &F) {
|
||||
auto Name = F.getName();
|
||||
|
||||
for (auto P : InternalPrefix)
|
||||
if (Name.startswith(P))
|
||||
return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
Value *convertToInt64Ty(IRBuilder<> &Builder, Value *V);
|
||||
|
||||
Value *convertFromInt64TyTo(IRBuilder<> &Builder, Value *V, Type *TargetTy);
|
||||
|
||||
// int device_wrapper(call_no, arg_info, ...) {
|
||||
// void *desc = __kmpc_host_rpc_get_desc(call_no, num_args, arg_info);
|
||||
// __kmpc_host_rpc_add_arg(desc, arg1, sizeof(arg1));
|
||||
// __kmpc_host_rpc_add_arg(desc, arg2, sizeof(arg2));
|
||||
// ...
|
||||
// __kmpc_host_rpc_send_and_wait(desc);
|
||||
// int r = (int)__kmpc_host_rpc_get_ret_val(desc);
|
||||
// return r;
|
||||
// }
|
||||
Function *getDeviceWrapperFunction(StringRef WrapperName, Function *F,
|
||||
CallSiteInfo &CSI);
|
||||
|
||||
// void host_wrapper(desc) {
|
||||
// int arg1 = (int)__kmpc_host_rpc_get_arg(desc, 0);
|
||||
// float arg2 = (float)__kmpc_host_rpc_get_arg(desc, 1);
|
||||
// char *arg3 = (char *)__kmpc_host_rpc_get_arg(desc, 2);
|
||||
// ...
|
||||
// int r = actual_call(arg1, arg2, arg3, ...);
|
||||
// __kmpc_host_rpc_set_ret_val(ptr(desc, (int64_t)r);
|
||||
// }
|
||||
Function *getHostWrapperFunction(StringRef WrapperName, Function *F,
|
||||
CallSiteInfo &CSI);
|
||||
|
||||
bool rewriteWithHostRPC(Function *F);
|
||||
|
||||
public:
|
||||
AutoHostRPC(Module &DeviceModule, Module &HostModule)
|
||||
: Context(DeviceModule.getContext()), DM(DeviceModule), HM(HostModule) {
|
||||
assert(&DeviceModule.getContext() == &HostModule.getContext() &&
|
||||
"device and host modules have different context");
|
||||
|
||||
#define __OMP_TYPE(TYPE) TYPE = Type::get##TYPE(Context)
|
||||
__OMP_TYPE(Int8PtrTy);
|
||||
__OMP_TYPE(VoidTy);
|
||||
__OMP_TYPE(Int32Ty);
|
||||
__OMP_TYPE(Int64Ty);
|
||||
#undef __OMP_TYPE
|
||||
|
||||
#define __OMP_RTL(_ENUM, MOD, VARARG, RETTY, ...) \
|
||||
{ \
|
||||
SmallVector<Type *> Params{__VA_ARGS__}; \
|
||||
FunctionType *FT = FunctionType::get(RETTY, Params, VARARG); \
|
||||
RFIs[OMPRTL_##_ENUM] = Function::Create( \
|
||||
FT, GlobalValue::LinkageTypes::InternalLinkage, #_ENUM, MOD); \
|
||||
}
|
||||
__OMP_RTL(__kmpc_host_rpc_get_desc, DM, false, Int8PtrTy, Int32Ty, Int32Ty,
|
||||
Int8PtrTy)
|
||||
__OMP_RTL(__kmpc_host_rpc_add_arg, DM, false, VoidTy, Int8PtrTy, Int64Ty,
|
||||
Int64Ty)
|
||||
__OMP_RTL(__kmpc_host_rpc_send_and_wait, DM, false, VoidTy, Int8PtrTy)
|
||||
__OMP_RTL(__kmpc_host_rpc_get_ret_val, DM, false, Int64Ty, Int8PtrTy)
|
||||
__OMP_RTL(__kmpc_host_rpc_get_arg, HM, false, Int64Ty, Int8PtrTy, Int32Ty)
|
||||
__OMP_RTL(__kmpc_host_rpc_set_ret_val, HM, false, VoidTy, Int8PtrTy,
|
||||
Int64Ty)
|
||||
#undef __OMP_RTL
|
||||
|
||||
ArgInfoTy = StructType::create({Int64Ty, Int64Ty}, "struct.arg_info_t");
|
||||
}
|
||||
|
||||
bool run();
|
||||
};
|
||||
|
||||
Value *AutoHostRPC::convertToInt64Ty(IRBuilder<> &Builder, Value *V) {
|
||||
Type *T = V->getType();
|
||||
|
||||
if (T == Int64Ty)
|
||||
return V;
|
||||
|
||||
if (T->isPointerTy())
|
||||
return Builder.CreatePtrToInt(V, Int64Ty);
|
||||
|
||||
if (T->isIntegerTy())
|
||||
return Builder.CreateIntCast(V, Int64Ty, false);
|
||||
|
||||
if (T->isFloatingPointTy()) {
|
||||
if (T->isFloatTy())
|
||||
V = Builder.CreateFPToSI(V, Int32Ty);
|
||||
return Builder.CreateFPToSI(V, Int64Ty);
|
||||
}
|
||||
|
||||
llvm_unreachable("unknown cast to int64_t");
|
||||
}
|
||||
|
||||
Value *AutoHostRPC::convertFromInt64TyTo(IRBuilder<> &Builder, Value *V,
|
||||
Type *T) {
|
||||
if (T == Int64Ty)
|
||||
return V;
|
||||
|
||||
if (T->isPointerTy())
|
||||
return Builder.CreateIntToPtr(V, T);
|
||||
|
||||
if (T->isIntegerTy())
|
||||
return Builder.CreateIntCast(V, T, /* isSigned */ true);
|
||||
|
||||
if (T->isFloatingPointTy()) {
|
||||
if (T->isFloatTy())
|
||||
V = Builder.CreateIntCast(V, Int32Ty, /* isSigned */ true);
|
||||
V = Builder.CreateSIToFP(V, T);
|
||||
return V;
|
||||
}
|
||||
|
||||
llvm_unreachable("unknown cast from int64_t");
|
||||
}
|
||||
|
||||
bool AutoHostRPC::run() {
|
||||
bool Changed = false;
|
||||
|
||||
SmallVector<Function *> WorkList;
|
||||
|
||||
for (Function &F : DM) {
|
||||
// If the function is already defined, it definitely does not require RPC.
|
||||
if (!F.isDeclaration())
|
||||
continue;
|
||||
|
||||
// If it is an internal function, skip it as well.
|
||||
if (isInternalFunction(F))
|
||||
continue;
|
||||
|
||||
// If there is no use of the function, skip it.
|
||||
if (F.use_empty())
|
||||
continue;
|
||||
|
||||
WorkList.push_back(&F);
|
||||
}
|
||||
|
||||
if (WorkList.empty())
|
||||
return Changed;
|
||||
|
||||
for (Function *F : WorkList)
|
||||
Changed |= rewriteWithHostRPC(F);
|
||||
|
||||
return Changed;
|
||||
}
|
||||
|
||||
bool AutoHostRPC::rewriteWithHostRPC(Function *F) {
|
||||
bool Changed = false;
|
||||
|
||||
SmallVector<CallInst *> WorkList;
|
||||
|
||||
for (User *U : F->users()) {
|
||||
auto *CI = dyn_cast<CallInst>(U);
|
||||
if (!CI)
|
||||
continue;
|
||||
WorkList.push_back(CI);
|
||||
}
|
||||
|
||||
if (WorkList.empty())
|
||||
return Changed;
|
||||
|
||||
for (CallInst *CI : WorkList) {
|
||||
CallSiteInfo CSI;
|
||||
CSI.CI = CI;
|
||||
unsigned NumArgs = CI->arg_size();
|
||||
for (unsigned I = 0; I < NumArgs; ++I)
|
||||
CSI.Params.push_back(CI->getArgOperand(I)->getType());
|
||||
|
||||
std::string WrapperName = getWrapperFunctionName(F, CSI);
|
||||
Function *DeviceWrapperFn = getDeviceWrapperFunction(WrapperName, F, CSI);
|
||||
Function *HostWrapperFn = getHostWrapperFunction(WrapperName, F, CSI);
|
||||
|
||||
int32_t WrapperNumber = -1;
|
||||
for (unsigned I = 0; I < HostEntryTable.size(); ++I) {
|
||||
if (HostEntryTable[I] == HostWrapperFn) {
|
||||
WrapperNumber = I;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (WrapperNumber == -1) {
|
||||
WrapperNumber = HostEntryTable.size();
|
||||
HostEntryTable.push_back(HostWrapperFn);
|
||||
}
|
||||
|
||||
DataLayout DL = DM.getDataLayout();
|
||||
IRBuilder<> Builder(CI);
|
||||
|
||||
auto CheckIfIdentifierPtr = [](const Value *V) {
|
||||
auto *CI = dyn_cast<CallInst>(V);
|
||||
if (!CI)
|
||||
return false;
|
||||
Function *Callee = CI->getCalledFunction();
|
||||
return Callee->getName().startswith("__kmpc_host_rpc_wrapper_");
|
||||
};
|
||||
|
||||
auto CheckIfAlloca = [](const Value *V) {
|
||||
auto *CI = dyn_cast<CallInst>(V);
|
||||
if (!CI)
|
||||
return false;
|
||||
Function *Callee = CI->getCalledFunction();
|
||||
return Callee->getName() == "__kmpc_alloc_shared" ||
|
||||
Callee->getName() == "malloc";
|
||||
};
|
||||
|
||||
SmallVector<HostRPCArgInfo> ArgInfos;
|
||||
bool IsConstantArgInfo = true;
|
||||
|
||||
for (Value *Op : CI->args()) {
|
||||
if (!Op->getType()->isPointerTy()) {
|
||||
HostRPCArgInfo AI{
|
||||
ConstantInt::get(Int64Ty,
|
||||
OMPHostRPCArgType::OMP_HOST_RPC_ARG_SCALAR),
|
||||
ConstantInt::getNullValue(Int64Ty)};
|
||||
ArgInfos.push_back(std::move(AI));
|
||||
continue;
|
||||
}
|
||||
|
||||
Value *SizeVal = nullptr;
|
||||
OMPHostRPCArgType ArgType = OMP_HOST_RPC_ARG_PTR_COPY_TOFROM;
|
||||
|
||||
SmallVector<const Value *> Objects;
|
||||
getUnderlyingObjects(Op, Objects);
|
||||
|
||||
// TODO: Handle phi node
|
||||
if (Objects.size() != 1)
|
||||
llvm_unreachable("we can't handle phi node yet");
|
||||
|
||||
auto *Obj = Objects.front();
|
||||
if (CheckIfIdentifierPtr(Obj)) {
|
||||
ArgType = OMP_HOST_RPC_ARG_SCALAR;
|
||||
SizeVal = ConstantInt::getNullValue(Int64Ty);
|
||||
} else if (CheckIfAlloca(Obj)) {
|
||||
auto *CI = dyn_cast<CallInst>(Obj);
|
||||
SizeVal = CI->getOperand(0);
|
||||
if (!isa<Constant>(SizeVal))
|
||||
IsConstantArgInfo = false;
|
||||
} else {
|
||||
if (auto *GV = dyn_cast<GlobalVariable>(Obj)) {
|
||||
SizeVal = ConstantInt::get(Int64Ty,
|
||||
DL.getTypeStoreSize(GV->getValueType()));
|
||||
if (GV->isConstant())
|
||||
ArgType = OMP_HOST_RPC_ARG_PTR_COPY_TO;
|
||||
if (GV->isConstant() && GV->hasInitializer()) {
|
||||
// TODO: If the global variable is contant, we can do some
|
||||
// optimization.
|
||||
}
|
||||
} else {
|
||||
// TODO: fix that when it occurs
|
||||
llvm_unreachable("cannot handle unknown type");
|
||||
}
|
||||
}
|
||||
|
||||
HostRPCArgInfo AI{ConstantInt::get(Int64Ty, ArgType), SizeVal};
|
||||
ArgInfos.push_back(std::move(AI));
|
||||
}
|
||||
|
||||
Value *ArgInfo = nullptr;
|
||||
|
||||
if (!IsConstantArgInfo) {
|
||||
ArgInfo = Builder.CreateAlloca(
|
||||
ArgInfoTy, ConstantInt::get(Int64Ty, NumArgs), "arg_info");
|
||||
for (unsigned I = 0; I < NumArgs; ++I) {
|
||||
Value *AII = GetElementPtrInst::Create(
|
||||
ArrayType::get(ArgInfoTy, NumArgs), ArgInfo,
|
||||
{ConstantInt::getNullValue(Int64Ty), ConstantInt::get(Int64Ty, I)});
|
||||
Value *AIIType = GetElementPtrInst::Create(
|
||||
ArgInfoTy, AII, {ConstantInt::get(Int64Ty, 0)});
|
||||
Value *AIISize = GetElementPtrInst::Create(
|
||||
ArgInfoTy, AII, {ConstantInt::get(Int64Ty, 1)});
|
||||
Builder.Insert(AII);
|
||||
Builder.Insert(AIIType);
|
||||
Builder.Insert(AIISize);
|
||||
Builder.CreateStore(ArgInfos[I].Type, AIIType);
|
||||
Builder.CreateStore(ArgInfos[I].Size, AIISize);
|
||||
}
|
||||
} else {
|
||||
SmallVector<Constant *> ArgInfoInitVar;
|
||||
for (auto &AI : ArgInfos) {
|
||||
auto *CS =
|
||||
ConstantStruct::get(ArgInfoTy, {AI.Type, cast<Constant>(AI.Size)});
|
||||
ArgInfoInitVar.push_back(CS);
|
||||
}
|
||||
Constant *ArgInfoInit = ConstantArray::get(
|
||||
ArrayType::get(ArgInfoTy, NumArgs), ArgInfoInitVar);
|
||||
ArgInfo = new GlobalVariable(
|
||||
DM, ArrayType::get(ArgInfoTy, NumArgs), /* isConstant */ true,
|
||||
GlobalValue::LinkageTypes::InternalLinkage, ArgInfoInit, "arg_info");
|
||||
}
|
||||
|
||||
SmallVector<Value *> Args{ConstantInt::get(Int32Ty, WrapperNumber),
|
||||
ArgInfo};
|
||||
for (Value *Op : CI->args())
|
||||
Args.push_back(Op);
|
||||
|
||||
CallInst *NewCall = Builder.CreateCall(DeviceWrapperFn, Args);
|
||||
|
||||
CI->replaceAllUsesWith(NewCall);
|
||||
CI->eraseFromParent();
|
||||
}
|
||||
|
||||
F->eraseFromParent();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
Function *AutoHostRPC::getDeviceWrapperFunction(StringRef WrapperName,
|
||||
Function *F,
|
||||
CallSiteInfo &CSI) {
|
||||
Function *WrapperFn = DM.getFunction(WrapperName);
|
||||
if (WrapperFn)
|
||||
return WrapperFn;
|
||||
|
||||
// return_type device_wrapper(int32_t call_no, void *arg_info, ...)
|
||||
SmallVector<Type *> Params{Int32Ty, Int8PtrTy};
|
||||
Params.append(CSI.Params);
|
||||
|
||||
Type *RetTy = F->getReturnType();
|
||||
|
||||
FunctionType *FT = FunctionType::get(RetTy, Params, /*isVarArg*/ false);
|
||||
WrapperFn = Function::Create(FT, GlobalValue::LinkageTypes::WeakODRLinkage,
|
||||
WrapperName, DM);
|
||||
|
||||
// Emit the body of the device wrapper
|
||||
IRBuilder<> Builder(Context);
|
||||
BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", WrapperFn);
|
||||
Builder.SetInsertPoint(EntryBB);
|
||||
|
||||
// skip call_no and arg_info.
|
||||
constexpr const unsigned NumArgSkipped = 2;
|
||||
|
||||
Value *Desc = nullptr;
|
||||
{
|
||||
Function *Fn = RFIs[OMPRTL___kmpc_host_rpc_get_desc];
|
||||
Desc = Builder.CreateCall(
|
||||
Fn,
|
||||
{WrapperFn->getArg(0),
|
||||
ConstantInt::get(Int32Ty, WrapperFn->arg_size() - NumArgSkipped),
|
||||
WrapperFn->getArg(1)},
|
||||
"desc");
|
||||
}
|
||||
|
||||
{
|
||||
Function *Fn = RFIs[OMPRTL___kmpc_host_rpc_add_arg];
|
||||
for (unsigned I = NumArgSkipped; I < WrapperFn->arg_size(); ++I) {
|
||||
Value *V = convertToInt64Ty(Builder, WrapperFn->getArg(I));
|
||||
Builder.CreateCall(Fn, {Desc, V, ConstantInt::getNullValue(Int64Ty)});
|
||||
}
|
||||
}
|
||||
|
||||
Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_send_and_wait], {Desc});
|
||||
|
||||
if (RetTy->isVoidTy()) {
|
||||
Builder.CreateRetVoid();
|
||||
return WrapperFn;
|
||||
}
|
||||
|
||||
Value *RetVal =
|
||||
Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_get_ret_val], {Desc});
|
||||
if (RetTy != RetVal->getType())
|
||||
RetVal = convertFromInt64TyTo(Builder, RetVal, RetTy);
|
||||
|
||||
Builder.CreateRet(RetVal);
|
||||
|
||||
return WrapperFn;
|
||||
}
|
||||
|
||||
Function *AutoHostRPC::getHostWrapperFunction(StringRef WrapperName,
|
||||
Function *F, CallSiteInfo &CSI) {
|
||||
Function *WrapperFn = HM.getFunction(WrapperName);
|
||||
if (WrapperFn)
|
||||
return WrapperFn;
|
||||
|
||||
SmallVector<Type *> Params{Int8PtrTy};
|
||||
FunctionType *FT = FunctionType::get(VoidTy, Params, /* isVarArg */ false);
|
||||
WrapperFn = Function::Create(FT, GlobalValue::LinkageTypes::ExternalLinkage,
|
||||
WrapperName, HM);
|
||||
|
||||
Value *Desc = WrapperFn->getArg(0);
|
||||
|
||||
// Emit the body of the host wrapper
|
||||
IRBuilder<> Builder(Context);
|
||||
BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", WrapperFn);
|
||||
Builder.SetInsertPoint(EntryBB);
|
||||
|
||||
SmallVector<Value *> Args;
|
||||
for (unsigned I = 0; I < CSI.CI->arg_size(); ++I) {
|
||||
Value *V = Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_get_arg],
|
||||
{Desc, ConstantInt::get(Int32Ty, I)});
|
||||
Args.push_back(convertFromInt64TyTo(Builder, V, CSI.Params[I]));
|
||||
}
|
||||
|
||||
// The host callee that will be called eventually by the host wrapper.
|
||||
Function *HostCallee = HM.getFunction(F->getName());
|
||||
if (!HostCallee)
|
||||
HostCallee = Function::Create(F->getFunctionType(), F->getLinkage(),
|
||||
F->getName(), HM);
|
||||
|
||||
Value *RetVal = Builder.CreateCall(HostCallee, Args);
|
||||
RetVal = convertToInt64Ty(Builder, RetVal);
|
||||
Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_set_ret_val], {Desc, RetVal});
|
||||
Builder.CreateRetVoid();
|
||||
|
||||
return WrapperFn;
|
||||
}
|
||||
|
||||
} // namespace llvm
|
||||
|
||||
int main(int argc, char *argv[]) {
|
||||
InitializeAllTargets();
|
||||
InitializeAllTargetMCs();
|
||||
InitializeAllAsmPrinters();
|
||||
InitializeAllAsmParsers();
|
||||
|
||||
SMDiagnostic Err;
|
||||
std::unique_ptr<Module> DM = parseIRFile("device.ll", Err, Context);
|
||||
if (!DM)
|
||||
return 1;
|
||||
Module HM("host-rpc.bc", Context);
|
||||
// get the right target triple
|
||||
HM.setTargetTriple(Triple::normalize("x86-64"));
|
||||
|
||||
AutoHostRPC RPC(*DM, HM);
|
||||
(void)RPC.run();
|
||||
|
||||
DM->dump();
|
||||
// HM.dump();
|
||||
|
||||
return 0;
|
||||
}
|
||||
19
auto-host-rpc/test.c
Normal file
19
auto-host-rpc/test.c
Normal file
@@ -0,0 +1,19 @@
|
||||
#include <stdio.h>
|
||||
|
||||
#pragma omp begin declare target device_type(nohost)
|
||||
|
||||
struct ddd {
|
||||
int num;
|
||||
int a;
|
||||
float b;
|
||||
};
|
||||
|
||||
void foo() {
|
||||
FILE *fp = fopen("main.cpp", "r");
|
||||
struct ddd d;
|
||||
fprintf(fp, "%d", 6);
|
||||
fprintf(fp, "%f%d%s", 6.0f, 1, "hello");
|
||||
fscanf(fp, "%d", &d.a);
|
||||
}
|
||||
|
||||
#pragma omp end declare target
|
||||
@@ -57,6 +57,8 @@ using namespace llvm;
|
||||
using namespace llvm::opt;
|
||||
using namespace llvm::object;
|
||||
|
||||
extern "C" int printf(const char *, ...);
|
||||
|
||||
/// Path of the current binary.
|
||||
static const char *LinkerExecutable;
|
||||
|
||||
@@ -86,6 +88,8 @@ static std::atomic<bool> LTOError;
|
||||
|
||||
using OffloadingImage = OffloadBinary::OffloadingImage;
|
||||
|
||||
Module *HostModule = nullptr;
|
||||
|
||||
namespace llvm {
|
||||
// Provide DenseMapInfo so that OffloadKind can be used in a DenseMap.
|
||||
template <> struct DenseMapInfo<OffloadKind> {
|
||||
@@ -774,6 +778,14 @@ Error linkBitcodeFiles(SmallVectorImpl<OffloadFile> &InputFiles,
|
||||
? createLTO(Args, Features, OutputBitcode)
|
||||
: createLTO(Args, Features);
|
||||
|
||||
LLVMContext &Ctx = LTOBackend->getContext();
|
||||
std::unique_ptr<Module> HostModulePtr =
|
||||
std::make_unique<Module>("host-rpc.bc", Ctx);
|
||||
HostModule = HostModulePtr.get();
|
||||
HostModulePtr->setTargetTriple(
|
||||
Args.getLastArgValue(OPT_host_triple_EQ, sys::getDefaultTargetTriple()));
|
||||
printf("[linker-wrapper] HostModule=%p\n", HostModule);
|
||||
|
||||
// We need to resolve the symbols so the LTO backend knows which symbols need
|
||||
// to be kept or can be internalized. This is a simplified symbol resolution
|
||||
// scheme to approximate the full resolution a linker would do.
|
||||
@@ -863,6 +875,10 @@ Error linkBitcodeFiles(SmallVectorImpl<OffloadFile> &InputFiles,
|
||||
if (Error Err = LTOBackend->run(AddStream))
|
||||
return Err;
|
||||
|
||||
// Reset the HostModule pointer.
|
||||
HostModulePtr.reset();
|
||||
HostModule = nullptr;
|
||||
|
||||
if (LTOError)
|
||||
return createStringError(inconvertibleErrorCode(),
|
||||
"Errors encountered inside the LTO pipeline.");
|
||||
|
||||
@@ -281,6 +281,9 @@ public:
|
||||
/// by LTO but might not be visible from bitcode symbol table.
|
||||
static ArrayRef<const char*> getRuntimeLibcallSymbols();
|
||||
|
||||
/// Returns the context.
|
||||
LLVMContext &getContext() { return RegularLTO.Ctx; }
|
||||
|
||||
private:
|
||||
Config Conf;
|
||||
|
||||
|
||||
24
llvm/include/llvm/Transforms/Utils/HostRPC.h
Normal file
24
llvm/include/llvm/Transforms/Utils/HostRPC.h
Normal file
@@ -0,0 +1,24 @@
|
||||
//===- Transform/Utils/HostRPC.h - Code of automatic host rpc ---*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef LLVM_TRANSFORMS_UTILS_HOSTRPC_H
|
||||
#define LLVM_TRANSFORMS_UTILS_HOSTRPC_H
|
||||
|
||||
#include "llvm/IR/PassManager.h"
|
||||
|
||||
namespace llvm {
|
||||
class HostRPCPass : public PassInfoMixin<HostRPCPass> {
|
||||
public:
|
||||
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
|
||||
};
|
||||
} // namespace llvm
|
||||
|
||||
#endif // LLVM_TRANSFORMS_UTILS_HOSTRPC_H
|
||||
@@ -228,6 +228,7 @@
|
||||
#include "llvm/Transforms/Utils/EntryExitInstrumenter.h"
|
||||
#include "llvm/Transforms/Utils/FixIrreducible.h"
|
||||
#include "llvm/Transforms/Utils/HelloWorld.h"
|
||||
#include "llvm/Transforms/Utils/HostRPC.h"
|
||||
#include "llvm/Transforms/Utils/InjectTLIMappings.h"
|
||||
#include "llvm/Transforms/Utils/InstructionNamer.h"
|
||||
#include "llvm/Transforms/Utils/LCSSA.h"
|
||||
|
||||
@@ -118,6 +118,7 @@
|
||||
#include "llvm/Transforms/Utils/AddDiscriminators.h"
|
||||
#include "llvm/Transforms/Utils/AssumeBundleBuilder.h"
|
||||
#include "llvm/Transforms/Utils/CanonicalizeAliases.h"
|
||||
#include "llvm/Transforms/Utils/HostRPC.h"
|
||||
#include "llvm/Transforms/Utils/InjectTLIMappings.h"
|
||||
#include "llvm/Transforms/Utils/LibCallsShrinkWrap.h"
|
||||
#include "llvm/Transforms/Utils/Mem2Reg.h"
|
||||
@@ -190,6 +191,9 @@ static cl::opt<bool> EnableGlobalAnalyses(
|
||||
"enable-global-analyses", cl::init(true), cl::Hidden,
|
||||
cl::desc("Enable inter-procedural analyses"));
|
||||
|
||||
cl::opt<bool> EnableHostRPC("enable-host-rpc", cl::init(false), cl::Hidden,
|
||||
cl::desc("Enable HostRPC pass"));
|
||||
|
||||
PipelineTuningOptions::PipelineTuningOptions() {
|
||||
LoopInterleaving = true;
|
||||
LoopVectorization = true;
|
||||
@@ -918,6 +922,9 @@ PassBuilder::buildModuleSimplificationPipeline(OptimizationLevel Level,
|
||||
PGOIndirectCallPromotion(true /* IsInLTO */, true /* SamplePGO */));
|
||||
}
|
||||
|
||||
if (EnableHostRPC)
|
||||
MPM.addPass(HostRPCPass());
|
||||
|
||||
// Try to perform OpenMP specific optimizations on the module. This is a
|
||||
// (quick!) no-op if there are no OpenMP runtime calls present in the module.
|
||||
if (Level != OptimizationLevel::O0)
|
||||
|
||||
@@ -127,6 +127,7 @@ MODULE_PASS("sanmd-module", SanitizerBinaryMetadataPass())
|
||||
MODULE_PASS("memprof-module", ModuleMemProfilerPass())
|
||||
MODULE_PASS("poison-checking", PoisonCheckingPass())
|
||||
MODULE_PASS("pseudo-probe-update", PseudoProbeUpdatePass())
|
||||
MODULE_PASS("host-rpc", HostRPCPass())
|
||||
#undef MODULE_PASS
|
||||
|
||||
#ifndef MODULE_PASS_WITH_PARAMS
|
||||
|
||||
@@ -29,6 +29,7 @@ add_llvm_component_library(LLVMTransformUtils
|
||||
GlobalStatus.cpp
|
||||
GuardUtils.cpp
|
||||
HelloWorld.cpp
|
||||
HostRPC.cpp
|
||||
InlineFunction.cpp
|
||||
InjectTLIMappings.cpp
|
||||
InstructionNamer.cpp
|
||||
|
||||
553
llvm/lib/Transforms/Utils/HostRPC.cpp
Normal file
553
llvm/lib/Transforms/Utils/HostRPC.cpp
Normal file
@@ -0,0 +1,553 @@
|
||||
//===- Transform/Utils/HostRPC.h - Code of automatic host rpc ---*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "llvm/Transforms/Utils/HostRPC.h"
|
||||
|
||||
#include "llvm/ADT/EnumeratedArray.h"
|
||||
#include "llvm/ADT/Triple.h"
|
||||
#include "llvm/Analysis/ValueTracking.h"
|
||||
#include "llvm/CodeGen/CommandFlags.h"
|
||||
#include "llvm/IR/BasicBlock.h"
|
||||
#include "llvm/IR/IRBuilder.h"
|
||||
#include "llvm/IR/Instructions.h"
|
||||
#include "llvm/IR/LLVMContext.h"
|
||||
#include "llvm/IR/Module.h"
|
||||
#include "llvm/IRReader/IRReader.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/TargetSelect.h"
|
||||
#include "llvm/Target/TargetOptions.h"
|
||||
|
||||
using namespace llvm;
|
||||
|
||||
#define DEBUG_TYPE "host-rpc"
|
||||
|
||||
extern "C" int printf(const char *, ...);
|
||||
|
||||
__attribute__((weak)) Module *HostModule = nullptr;
|
||||
|
||||
namespace {
|
||||
// TODO: Remove those functions implemented in device runtime.
|
||||
static constexpr const char *InternalPrefix[] = {
|
||||
"__kmp", "llvm.", "nvm.", "omp_", "vprintf", "malloc", "free"};
|
||||
|
||||
std::string typeToString(Type *T) {
|
||||
if (T->is16bitFPTy())
|
||||
return "f16";
|
||||
if (T->isFloatTy())
|
||||
return "f32";
|
||||
if (T->isDoubleTy())
|
||||
return "f64";
|
||||
if (T->isPointerTy())
|
||||
return "ptr";
|
||||
if (T->isStructTy())
|
||||
return std::string(T->getStructName());
|
||||
if (T->isIntegerTy())
|
||||
return "i" + std::to_string(T->getIntegerBitWidth());
|
||||
|
||||
llvm_unreachable("unknown type");
|
||||
}
|
||||
|
||||
enum class HostRPCRuntimeFunction {
|
||||
#define __OMPRTL_HOST_RPC(_ENUM) OMPRTL_##_ENUM
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_desc),
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_add_arg),
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_arg),
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_send_and_wait),
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_ret_val),
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_set_ret_val),
|
||||
__OMPRTL_HOST_RPC(__last),
|
||||
#undef __OMPRTL_HOST_RPC
|
||||
};
|
||||
|
||||
#define __OMPRTL_HOST_RPC(_ENUM) \
|
||||
auto OMPRTL_##_ENUM = HostRPCRuntimeFunction::OMPRTL_##_ENUM;
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_desc)
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_add_arg)
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_arg)
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_send_and_wait)
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_get_ret_val)
|
||||
__OMPRTL_HOST_RPC(__kmpc_host_rpc_set_ret_val)
|
||||
#undef __OMPRTL_HOST_RPC
|
||||
|
||||
enum OMPHostRPCArgType {
|
||||
// No need to copy.
|
||||
OMP_HOST_RPC_ARG_SCALAR = 0,
|
||||
OMP_HOST_RPC_ARG_PTR = 1,
|
||||
// Copy to host.
|
||||
OMP_HOST_RPC_ARG_PTR_COPY_TO = 2,
|
||||
// Copy to device
|
||||
OMP_HOST_RPC_ARG_PTR_COPY_FROM = 3,
|
||||
// TODO: Do we have a tofrom pointer?
|
||||
OMP_HOST_RPC_ARG_PTR_COPY_TOFROM = 4,
|
||||
};
|
||||
|
||||
class HostRPC {
|
||||
LLVMContext &Context;
|
||||
// Device module
|
||||
Module &M;
|
||||
// Types
|
||||
Type *Int8PtrTy;
|
||||
Type *VoidTy;
|
||||
Type *Int32Ty;
|
||||
Type *Int64Ty;
|
||||
StructType *ArgInfoTy;
|
||||
|
||||
struct CallSiteInfo {
|
||||
CallInst *CI = nullptr;
|
||||
SmallVector<Type *> Params;
|
||||
};
|
||||
|
||||
struct HostRPCArgInfo {
|
||||
// OMPHostRPCArgType
|
||||
Constant *Type;
|
||||
Value *Size;
|
||||
};
|
||||
|
||||
//
|
||||
SmallVector<Function *> HostEntryTable;
|
||||
|
||||
EnumeratedArray<Function *, HostRPCRuntimeFunction,
|
||||
HostRPCRuntimeFunction::OMPRTL___last>
|
||||
RFIs;
|
||||
|
||||
static std::string getWrapperFunctionName(Function *F, CallSiteInfo &CSI) {
|
||||
std::string Name = "__kmpc_host_rpc_wrapper_" + std::string(F->getName());
|
||||
if (!F->isVarArg())
|
||||
return Name;
|
||||
|
||||
for (unsigned I = F->getFunctionType()->getNumParams();
|
||||
I < CSI.Params.size(); ++I) {
|
||||
Name.push_back('_');
|
||||
Name.append(typeToString(CSI.Params[I]));
|
||||
}
|
||||
|
||||
return Name;
|
||||
}
|
||||
|
||||
static bool isInternalFunction(Function &F) {
|
||||
auto Name = F.getName();
|
||||
|
||||
for (auto P : InternalPrefix)
|
||||
if (Name.startswith(P))
|
||||
return true;
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
Value *convertToInt64Ty(IRBuilder<> &Builder, Value *V);
|
||||
|
||||
Value *convertFromInt64TyTo(IRBuilder<> &Builder, Value *V, Type *TargetTy);
|
||||
|
||||
// int device_wrapper(call_no, arg_info, ...) {
|
||||
// void *desc = __kmpc_host_rpc_get_desc(call_no, num_args, arg_info);
|
||||
// __kmpc_host_rpc_add_arg(desc, arg1, sizeof(arg1));
|
||||
// __kmpc_host_rpc_add_arg(desc, arg2, sizeof(arg2));
|
||||
// ...
|
||||
// __kmpc_host_rpc_send_and_wait(desc);
|
||||
// int r = (int)__kmpc_host_rpc_get_ret_val(desc);
|
||||
// return r;
|
||||
// }
|
||||
Function *getDeviceWrapperFunction(StringRef WrapperName, Function *F,
|
||||
CallSiteInfo &CSI);
|
||||
|
||||
// void host_wrapper(desc) {
|
||||
// int arg1 = (int)__kmpc_host_rpc_get_arg(desc, 0);
|
||||
// float arg2 = (float)__kmpc_host_rpc_get_arg(desc, 1);
|
||||
// char *arg3 = (char *)__kmpc_host_rpc_get_arg(desc, 2);
|
||||
// ...
|
||||
// int r = actual_call(arg1, arg2, arg3, ...);
|
||||
// __kmpc_host_rpc_set_ret_val(ptr(desc, (int64_t)r);
|
||||
// }
|
||||
Function *getHostWrapperFunction(StringRef WrapperName, Function *F,
|
||||
CallSiteInfo &CSI);
|
||||
|
||||
bool rewriteWithHostRPC(Function *F);
|
||||
|
||||
public:
|
||||
HostRPC(Module &M) : Context(M.getContext()), M(M) {
|
||||
assert(&M.getContext() == &HostModule->getContext() &&
|
||||
"device and host modules have different context");
|
||||
|
||||
#define __OMP_TYPE(TYPE) TYPE = Type::get##TYPE(Context)
|
||||
__OMP_TYPE(Int8PtrTy);
|
||||
__OMP_TYPE(VoidTy);
|
||||
__OMP_TYPE(Int32Ty);
|
||||
__OMP_TYPE(Int64Ty);
|
||||
#undef __OMP_TYPE
|
||||
|
||||
#define __OMP_RTL(_ENUM, MOD, VARARG, RETTY, ...) \
|
||||
{ \
|
||||
SmallVector<Type *> Params{__VA_ARGS__}; \
|
||||
FunctionType *FT = FunctionType::get(RETTY, Params, VARARG); \
|
||||
Function *F = (MOD).getFunction(#_ENUM); \
|
||||
if (!F) \
|
||||
F = Function::Create(FT, GlobalValue::LinkageTypes::ExternalLinkage, \
|
||||
#_ENUM, (MOD)); \
|
||||
RFIs[OMPRTL_##_ENUM] = F; \
|
||||
}
|
||||
__OMP_RTL(__kmpc_host_rpc_get_desc, M, false, Int8PtrTy, Int32Ty, Int32Ty,
|
||||
Int8PtrTy)
|
||||
__OMP_RTL(__kmpc_host_rpc_add_arg, M, false, VoidTy, Int8PtrTy, Int64Ty,
|
||||
Int32Ty)
|
||||
__OMP_RTL(__kmpc_host_rpc_send_and_wait, M, false, VoidTy, Int8PtrTy)
|
||||
__OMP_RTL(__kmpc_host_rpc_get_ret_val, M, false, Int64Ty, Int8PtrTy)
|
||||
__OMP_RTL(__kmpc_host_rpc_get_arg, *HostModule, false, Int64Ty, Int8PtrTy,
|
||||
Int32Ty)
|
||||
__OMP_RTL(__kmpc_host_rpc_set_ret_val, *HostModule, false, VoidTy,
|
||||
Int8PtrTy, Int64Ty)
|
||||
#undef __OMP_RTL
|
||||
|
||||
ArgInfoTy = StructType::create({Int64Ty, Int64Ty}, "struct.arg_info_t");
|
||||
}
|
||||
|
||||
bool run();
|
||||
}; // namespace
|
||||
|
||||
Value *HostRPC::convertToInt64Ty(IRBuilder<> &Builder, Value *V) {
|
||||
Type *T = V->getType();
|
||||
|
||||
if (T == Int64Ty)
|
||||
return V;
|
||||
|
||||
if (T->isPointerTy())
|
||||
return Builder.CreatePtrToInt(V, Int64Ty);
|
||||
|
||||
if (T->isIntegerTy())
|
||||
return Builder.CreateIntCast(V, Int64Ty, false);
|
||||
|
||||
if (T->isFloatingPointTy()) {
|
||||
if (T->isFloatTy())
|
||||
V = Builder.CreateFPToSI(V, Int32Ty);
|
||||
return Builder.CreateFPToSI(V, Int64Ty);
|
||||
}
|
||||
|
||||
llvm_unreachable("unknown cast to int64_t");
|
||||
}
|
||||
|
||||
Value *HostRPC::convertFromInt64TyTo(IRBuilder<> &Builder, Value *V, Type *T) {
|
||||
if (T == Int64Ty)
|
||||
return V;
|
||||
|
||||
if (T->isPointerTy())
|
||||
return Builder.CreateIntToPtr(V, T);
|
||||
|
||||
if (T->isIntegerTy())
|
||||
return Builder.CreateIntCast(V, T, /* isSigned */ true);
|
||||
|
||||
if (T->isFloatingPointTy()) {
|
||||
if (T->isFloatTy())
|
||||
V = Builder.CreateIntCast(V, Int32Ty, /* isSigned */ true);
|
||||
V = Builder.CreateSIToFP(V, T);
|
||||
return V;
|
||||
}
|
||||
|
||||
llvm_unreachable("unknown cast from int64_t");
|
||||
}
|
||||
|
||||
bool HostRPC::run() {
|
||||
bool Changed = false;
|
||||
|
||||
SmallVector<Function *> WorkList;
|
||||
|
||||
for (Function &F : M) {
|
||||
// If the function is already defined, it definitely does not require RPC.
|
||||
if (!F.isDeclaration())
|
||||
continue;
|
||||
|
||||
// If it is an internal function, skip it as well.
|
||||
if (isInternalFunction(F))
|
||||
continue;
|
||||
|
||||
// If there is no use of the function, skip it.
|
||||
if (F.use_empty())
|
||||
continue;
|
||||
|
||||
WorkList.push_back(&F);
|
||||
}
|
||||
|
||||
if (WorkList.empty())
|
||||
return Changed;
|
||||
|
||||
for (Function *F : WorkList)
|
||||
Changed |= rewriteWithHostRPC(F);
|
||||
|
||||
return Changed;
|
||||
}
|
||||
|
||||
bool HostRPC::rewriteWithHostRPC(Function *F) {
|
||||
bool Changed = false;
|
||||
|
||||
SmallVector<CallInst *> WorkList;
|
||||
|
||||
for (User *U : F->users()) {
|
||||
auto *CI = dyn_cast<CallInst>(U);
|
||||
if (!CI)
|
||||
continue;
|
||||
WorkList.push_back(CI);
|
||||
}
|
||||
|
||||
if (WorkList.empty())
|
||||
return Changed;
|
||||
|
||||
for (CallInst *CI : WorkList) {
|
||||
CallSiteInfo CSI;
|
||||
CSI.CI = CI;
|
||||
unsigned NumArgs = CI->arg_size();
|
||||
for (unsigned I = 0; I < NumArgs; ++I)
|
||||
CSI.Params.push_back(CI->getArgOperand(I)->getType());
|
||||
|
||||
std::string WrapperName = getWrapperFunctionName(F, CSI);
|
||||
Function *DeviceWrapperFn = getDeviceWrapperFunction(WrapperName, F, CSI);
|
||||
Function *HostWrapperFn = getHostWrapperFunction(WrapperName, F, CSI);
|
||||
|
||||
int32_t WrapperNumber = -1;
|
||||
for (unsigned I = 0; I < HostEntryTable.size(); ++I) {
|
||||
if (HostEntryTable[I] == HostWrapperFn) {
|
||||
WrapperNumber = I;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (WrapperNumber == -1) {
|
||||
WrapperNumber = HostEntryTable.size();
|
||||
HostEntryTable.push_back(HostWrapperFn);
|
||||
}
|
||||
|
||||
DataLayout DL = M.getDataLayout();
|
||||
IRBuilder<> Builder(CI);
|
||||
|
||||
auto CheckIfIdentifierPtr = [](const Value *V) {
|
||||
auto *CI = dyn_cast<CallInst>(V);
|
||||
if (!CI)
|
||||
return false;
|
||||
Function *Callee = CI->getCalledFunction();
|
||||
return Callee->getName().startswith("__kmpc_host_rpc_wrapper_");
|
||||
};
|
||||
|
||||
auto CheckIfAlloca = [](const Value *V) {
|
||||
auto *CI = dyn_cast<CallInst>(V);
|
||||
if (!CI)
|
||||
return false;
|
||||
Function *Callee = CI->getCalledFunction();
|
||||
return Callee->getName() == "__kmpc_alloc_shared" ||
|
||||
Callee->getName() == "malloc";
|
||||
};
|
||||
|
||||
SmallVector<HostRPCArgInfo> ArgInfos;
|
||||
bool IsConstantArgInfo = true;
|
||||
|
||||
for (Value *Op : CI->args()) {
|
||||
if (!Op->getType()->isPointerTy()) {
|
||||
HostRPCArgInfo AI{
|
||||
ConstantInt::get(Int64Ty,
|
||||
OMPHostRPCArgType::OMP_HOST_RPC_ARG_SCALAR),
|
||||
ConstantInt::getNullValue(Int64Ty)};
|
||||
ArgInfos.push_back(std::move(AI));
|
||||
continue;
|
||||
}
|
||||
|
||||
Value *SizeVal = nullptr;
|
||||
OMPHostRPCArgType ArgType = OMP_HOST_RPC_ARG_PTR_COPY_TOFROM;
|
||||
|
||||
SmallVector<const Value *> Objects;
|
||||
getUnderlyingObjects(Op, Objects);
|
||||
|
||||
// TODO: Handle phi node
|
||||
if (Objects.size() != 1)
|
||||
llvm_unreachable("we can't handle phi node yet");
|
||||
|
||||
auto *Obj = Objects.front();
|
||||
if (CheckIfIdentifierPtr(Obj)) {
|
||||
ArgType = OMP_HOST_RPC_ARG_SCALAR;
|
||||
SizeVal = ConstantInt::getNullValue(Int64Ty);
|
||||
} else if (CheckIfAlloca(Obj)) {
|
||||
auto *CI = dyn_cast<CallInst>(Obj);
|
||||
SizeVal = CI->getOperand(0);
|
||||
if (!isa<Constant>(SizeVal))
|
||||
IsConstantArgInfo = false;
|
||||
} else {
|
||||
if (auto *GV = dyn_cast<GlobalVariable>(Obj)) {
|
||||
SizeVal = ConstantInt::get(Int64Ty,
|
||||
DL.getTypeStoreSize(GV->getValueType()));
|
||||
if (GV->isConstant())
|
||||
ArgType = OMP_HOST_RPC_ARG_PTR_COPY_TO;
|
||||
if (GV->isConstant() && GV->hasInitializer()) {
|
||||
// TODO: If the global variable is contant, we can do some
|
||||
// optimization.
|
||||
}
|
||||
} else {
|
||||
// TODO: fix that when it occurs
|
||||
llvm_unreachable("cannot handle unknown type");
|
||||
}
|
||||
}
|
||||
|
||||
HostRPCArgInfo AI{ConstantInt::get(Int64Ty, ArgType), SizeVal};
|
||||
ArgInfos.push_back(std::move(AI));
|
||||
}
|
||||
|
||||
Value *ArgInfo = nullptr;
|
||||
|
||||
if (!IsConstantArgInfo) {
|
||||
ArgInfo = Builder.CreateAlloca(
|
||||
ArgInfoTy, ConstantInt::get(Int64Ty, NumArgs), "arg_info");
|
||||
for (unsigned I = 0; I < NumArgs; ++I) {
|
||||
Value *AII = GetElementPtrInst::Create(
|
||||
ArrayType::get(ArgInfoTy, NumArgs), ArgInfo,
|
||||
{ConstantInt::getNullValue(Int64Ty), ConstantInt::get(Int64Ty, I)});
|
||||
Value *AIIType = GetElementPtrInst::Create(
|
||||
ArgInfoTy, AII, {ConstantInt::get(Int64Ty, 0)});
|
||||
Value *AIISize = GetElementPtrInst::Create(
|
||||
ArgInfoTy, AII, {ConstantInt::get(Int64Ty, 1)});
|
||||
Builder.Insert(AII);
|
||||
Builder.Insert(AIIType);
|
||||
Builder.Insert(AIISize);
|
||||
Builder.CreateStore(ArgInfos[I].Type, AIIType);
|
||||
Builder.CreateStore(ArgInfos[I].Size, AIISize);
|
||||
}
|
||||
} else {
|
||||
SmallVector<Constant *> ArgInfoInitVar;
|
||||
for (auto &AI : ArgInfos) {
|
||||
auto *CS =
|
||||
ConstantStruct::get(ArgInfoTy, {AI.Type, cast<Constant>(AI.Size)});
|
||||
ArgInfoInitVar.push_back(CS);
|
||||
}
|
||||
Constant *ArgInfoInit = ConstantArray::get(
|
||||
ArrayType::get(ArgInfoTy, NumArgs), ArgInfoInitVar);
|
||||
ArgInfo = new GlobalVariable(
|
||||
M, ArrayType::get(ArgInfoTy, NumArgs), /* isConstant */ true,
|
||||
GlobalValue::LinkageTypes::InternalLinkage, ArgInfoInit, "arg_info");
|
||||
}
|
||||
|
||||
SmallVector<Value *> Args{ConstantInt::get(Int32Ty, WrapperNumber),
|
||||
ArgInfo};
|
||||
for (Value *Op : CI->args())
|
||||
Args.push_back(Op);
|
||||
|
||||
CallInst *NewCall = Builder.CreateCall(DeviceWrapperFn, Args);
|
||||
|
||||
CI->replaceAllUsesWith(NewCall);
|
||||
CI->eraseFromParent();
|
||||
}
|
||||
|
||||
F->eraseFromParent();
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
Function *HostRPC::getDeviceWrapperFunction(StringRef WrapperName, Function *F,
|
||||
CallSiteInfo &CSI) {
|
||||
Function *WrapperFn = M.getFunction(WrapperName);
|
||||
if (WrapperFn)
|
||||
return WrapperFn;
|
||||
|
||||
// return_type device_wrapper(int32_t call_no, void *arg_info, ...)
|
||||
SmallVector<Type *> Params{Int32Ty, Int8PtrTy};
|
||||
Params.append(CSI.Params);
|
||||
|
||||
Type *RetTy = F->getReturnType();
|
||||
|
||||
FunctionType *FT = FunctionType::get(RetTy, Params, /*isVarArg*/ false);
|
||||
WrapperFn = Function::Create(FT, GlobalValue::LinkageTypes::WeakODRLinkage,
|
||||
WrapperName, M);
|
||||
|
||||
// Emit the body of the device wrapper
|
||||
IRBuilder<> Builder(Context);
|
||||
BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", WrapperFn);
|
||||
Builder.SetInsertPoint(EntryBB);
|
||||
|
||||
// skip call_no and arg_info.
|
||||
constexpr const unsigned NumArgSkipped = 2;
|
||||
|
||||
Value *Desc = nullptr;
|
||||
{
|
||||
Function *Fn = RFIs[OMPRTL___kmpc_host_rpc_get_desc];
|
||||
Desc = Builder.CreateCall(
|
||||
Fn,
|
||||
{WrapperFn->getArg(0),
|
||||
ConstantInt::get(Int32Ty, WrapperFn->arg_size() - NumArgSkipped),
|
||||
WrapperFn->getArg(1)},
|
||||
"desc");
|
||||
}
|
||||
|
||||
{
|
||||
Function *Fn = RFIs[OMPRTL___kmpc_host_rpc_add_arg];
|
||||
for (unsigned I = NumArgSkipped; I < WrapperFn->arg_size(); ++I) {
|
||||
Value *V = convertToInt64Ty(Builder, WrapperFn->getArg(I));
|
||||
Builder.CreateCall(Fn, {Desc, V, ConstantInt::get(Int32Ty, I)});
|
||||
}
|
||||
}
|
||||
|
||||
Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_send_and_wait], {Desc});
|
||||
|
||||
if (RetTy->isVoidTy()) {
|
||||
Builder.CreateRetVoid();
|
||||
return WrapperFn;
|
||||
}
|
||||
|
||||
Value *RetVal =
|
||||
Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_get_ret_val], {Desc});
|
||||
if (RetTy != RetVal->getType())
|
||||
RetVal = convertFromInt64TyTo(Builder, RetVal, RetTy);
|
||||
|
||||
Builder.CreateRet(RetVal);
|
||||
|
||||
return WrapperFn;
|
||||
}
|
||||
|
||||
Function *HostRPC::getHostWrapperFunction(StringRef WrapperName, Function *F,
|
||||
CallSiteInfo &CSI) {
|
||||
Function *WrapperFn = HostModule->getFunction(WrapperName);
|
||||
if (WrapperFn)
|
||||
return WrapperFn;
|
||||
|
||||
SmallVector<Type *> Params{Int8PtrTy};
|
||||
FunctionType *FT = FunctionType::get(VoidTy, Params, /* isVarArg */ false);
|
||||
WrapperFn = Function::Create(FT, GlobalValue::LinkageTypes::ExternalLinkage,
|
||||
WrapperName, *HostModule);
|
||||
|
||||
Value *Desc = WrapperFn->getArg(0);
|
||||
|
||||
// Emit the body of the host wrapper
|
||||
IRBuilder<> Builder(Context);
|
||||
BasicBlock *EntryBB = BasicBlock::Create(Context, "entry", WrapperFn);
|
||||
Builder.SetInsertPoint(EntryBB);
|
||||
|
||||
SmallVector<Value *> Args;
|
||||
for (unsigned I = 0; I < CSI.CI->arg_size(); ++I) {
|
||||
Value *V = Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_get_arg],
|
||||
{Desc, ConstantInt::get(Int32Ty, I)});
|
||||
Args.push_back(convertFromInt64TyTo(Builder, V, CSI.Params[I]));
|
||||
}
|
||||
|
||||
// The host callee that will be called eventually by the host wrapper.
|
||||
Function *HostCallee = HostModule->getFunction(F->getName());
|
||||
if (!HostCallee)
|
||||
HostCallee = Function::Create(F->getFunctionType(), F->getLinkage(),
|
||||
F->getName(), *HostModule);
|
||||
|
||||
Value *RetVal = Builder.CreateCall(HostCallee, Args);
|
||||
RetVal = convertToInt64Ty(Builder, RetVal);
|
||||
Builder.CreateCall(RFIs[OMPRTL___kmpc_host_rpc_set_ret_val], {Desc, RetVal});
|
||||
Builder.CreateRetVoid();
|
||||
|
||||
return WrapperFn;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
PreservedAnalyses HostRPCPass::run(Module &M, ModuleAnalysisManager &AM) {
|
||||
printf("[HostRPCPass] HostModule=%p\n", HostModule);
|
||||
if (!HostModule)
|
||||
return PreservedAnalyses::all();
|
||||
|
||||
HostRPC RPC(M);
|
||||
bool Changed = RPC.run();
|
||||
|
||||
return Changed ? PreservedAnalyses::all() : PreservedAnalyses::none();
|
||||
}
|
||||
@@ -100,6 +100,7 @@ set(include_files
|
||||
set(src_files
|
||||
${source_directory}/Configuration.cpp
|
||||
${source_directory}/Debug.cpp
|
||||
${source_directory}/HostRPC.cpp
|
||||
${source_directory}/Kernel.cpp
|
||||
${source_directory}/LibC.cpp
|
||||
${source_directory}/Mapping.cpp
|
||||
|
||||
@@ -351,6 +351,15 @@ int32_t __kmpc_cancel(IdentTy *Loc, int32_t TId, int32_t CancelVal);
|
||||
int32_t __kmpc_shuffle_int32(int32_t val, int16_t delta, int16_t size);
|
||||
int64_t __kmpc_shuffle_int64(int64_t val, int16_t delta, int16_t size);
|
||||
///}
|
||||
|
||||
/// Host RPC
|
||||
///
|
||||
/// {
|
||||
void *__kmpc_host_rpc_get_desc(int32_t CallNo, int32_t NumArgs, void *ArgInfo);
|
||||
void __kmpc_host_rpc_add_arg(void *Desc, int64_t Arg, int32_t ArgNum);
|
||||
void __kmpc_host_rpc_send_and_wait(void *Desc);
|
||||
int64_t __kmpc_host_rpc_get_ret_val(void *Desc);
|
||||
/// }
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
25
openmp/libomptarget/DeviceRTL/src/HostRPC.cpp
Normal file
25
openmp/libomptarget/DeviceRTL/src/HostRPC.cpp
Normal file
@@ -0,0 +1,25 @@
|
||||
//===------- HostRPC.cpp - Implementation of host RPC ------------- C++ ---===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "Types.h"
|
||||
|
||||
#pragma omp begin declare target device_type(nohost)
|
||||
|
||||
extern "C" {
|
||||
void *__kmpc_host_rpc_get_desc(int32_t CallNo, int32_t NumArgs, void *ArgInfo) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
void __kmpc_host_rpc_add_arg(void *Desc, int64_t Arg, int32_t ArgNum) {}
|
||||
|
||||
void __kmpc_host_rpc_send_and_wait(void *Desc) {}
|
||||
|
||||
int64_t __kmpc_host_rpc_get_ret_val(void *Desc) { return 0; }
|
||||
}
|
||||
|
||||
#pragma omp end declare target
|
||||
Reference in New Issue
Block a user