Fixe HostRPC constant adress spaces.

Cast constant declare in adress space 1 to adress space 0.
This is needed for rpc calls as a function call
cannot pass as arguments to a function a pointer in address space 1.
This commit is contained in:
Nicolas Marie
2024-08-30 10:34:28 -07:00
parent 7104affc8e
commit f005107d46

View File

@@ -74,7 +74,9 @@ static constexpr const char *InternalPrefix[] = {
"__kmp", "llvm.", "nvm.",
"omp_", "vprintf", "malloc",
"free", "__keep_alive", "__llvm_omp_vprintf",
"rpc_", "MPI_", "fprintf", "sprintf"
"rpc_", "MPI_", "fprintf", "sprintf",
"sscanf", "fopen", "fgets", "fclose", "usleep",
//"time", "localtime", "asctime", "uname",
};
bool isInternalFunction(Function &F) {
@@ -673,8 +675,7 @@ bool HostRPC::rewriteWithHostRPC(Function *F) {
Value *ArgInfoVal = nullptr;
if (!IsConstantArgInfo) {
ArgInfoVal = Builder.CreateAlloca(Int8PtrTy, getConstantInt64(NumArgs),
"arg_info");
ArgInfoVal = Builder.CreateAddrSpaceCast(Builder.CreateAlloca(Int8PtrTy, getConstantInt64(NumArgs), "arg_info"), PointerType::get(M.getContext(), 0));
for (unsigned I = 0; I < NumArgs; ++I) {
auto &AII = ArgInfo[I];
Value *Next = NullPtr;
@@ -717,13 +718,11 @@ bool HostRPC::rewriteWithHostRPC(Function *F) {
auto *CS =
ConstantStruct::get(ArgInfoTy, {convertToInt64Ty(Arg), AI.Type,
cast<Constant>(AI.Size), Last});
auto *GV = new GlobalVariable(
M, ArgInfoTy, /* isConstant */ true,
GlobalValue::LinkageTypes::InternalLinkage, CS, "",
nullptr, GlobalValue::ThreadLocalMode::NotThreadLocal, 0);
auto *GV = new GlobalVariable(M, ArgInfoTy, /* isConstant */ true,
GlobalValue::LinkageTypes::InternalLinkage, CS);
// force adress space 0 on AMD GPU
// insted of address space 1 for globals
Last = GV;
Last = ConstantExpr::getAddrSpaceCast(GV, PointerType::get(M.getContext(), 0));
}
LLVM_DEBUG({
dbgs() << "[HostRPC] ArgInfoInitVar:" << *Last << "\n";
@@ -734,16 +733,24 @@ bool HostRPC::rewriteWithHostRPC(Function *F) {
Constant *ArgInfoInit = ConstantArray::get(
ArrayType::get(Int8PtrTy, NumArgs), ArgInfoInitVar);
ArgInfoVal = new GlobalVariable(
M, ArrayType::get(Int8PtrTy, NumArgs), /* isConstant */ true,
GlobalValue::LinkageTypes::InternalLinkage, ArgInfoInit, "arg_info",
nullptr, GlobalValue::ThreadLocalMode::NotThreadLocal, 0);
ArgInfoVal = ConstantExpr::getAddrSpaceCast(
new GlobalVariable(M, ArrayType::get(Int8PtrTy, NumArgs), /* isConstant */ true,
GlobalValue::LinkageTypes::InternalLinkage, ArgInfoInit),
PointerType::get(M.getContext(), 0));
}
SmallVector<Value *> Args{ConstantInt::get(Int32Ty, WrapperNumber),
ArgInfoVal};
for (Value *Operand : CI->args())
for (Value *Operand : CI->args()) {
Args.push_back(Operand);
LLVM_DEBUG({
dbgs() << "[HostRPC] Arg:" << *Operand << "|" << *Operand->getType() << "\n";
});
}
LLVM_DEBUG({
dbgs() << "[HostRPC] Function:" << *DeviceWrapperFn << "\n";
});
CallInst *NewCall = Builder.CreateCall(DeviceWrapperFn, Args);
@@ -780,7 +787,7 @@ Function *HostRPC::getDeviceWrapperFunction(StringRef WrapperName, Function *F,
{
Function *Fn = RFIs[OMPRTL___kmpc_host_rpc_get_desc];
//LLVM_DEBUG({dbgs() << "[HostRPC] Building: rpc get desc: " << Fn->getName() << "\n"; });
for (unsigned i = 0; i < 3; ++i)
//for (unsigned i = 0; i < 3; ++i)
//LLVM_DEBUG({dbgs() << "ParamI: " << *(Fn->getFunctionType()->getParamType(i)) << "\n"; });
Desc = Builder.CreateCall(