Try to add support for rpc vprintf.

Try to add support for vprintf using rpc without implementing variadic arguments.
Using Buildtin to transform fprintf call to vfprintf & then replace vfprintf call to a call to __omp_vprintf.
This is the methode currently used to make printf work using rpcs.
fprintf is not recognise as a builtin for now, so it does not work.
This commit is contained in:
Nicolas Marie
2024-08-30 10:40:22 -07:00
parent e9d0f287b4
commit daac07a682
4 changed files with 96 additions and 2 deletions

View File

@@ -5864,7 +5864,14 @@ RValue CodeGenFunction::EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
if (getTarget().getTriple().isAMDGCN() && getLangOpts().HIP)
return EmitAMDGPUDevicePrintfCallExpr(E);
}
break;
case Builtin::BI__builtin_fprintf:
case Builtin::BIfprintf:
if (getTarget().getTriple().isNVPTX() ||
getTarget().getTriple().isAMDGCN()) {
if (getLangOpts().OpenMPIsTargetDevice)
return EmitOpenMPDeviceFPrintfCallExpr(E);
}
break;
case Builtin::BI__builtin_canonicalize:
case Builtin::BI__builtin_canonicalizef:

View File

@@ -64,6 +64,28 @@ llvm::Function *GetOpenMPVprintfDeclaration(CodeGenModule &CGM) {
VprintfFuncType, llvm::GlobalVariable::ExternalLinkage, Name, &M);
}
llvm::Function *GetOpenMPVfprintfDeclaration(CodeGenModule &CGM) {
const char *Name = "__llvm_omp_vfprintf";
llvm::Module &M = CGM.getModule();
llvm::Type *ArgTypes[] = {llvm::PointerType::getUnqual(M.getContext()),
llvm::PointerType::getUnqual(M.getContext()),
llvm::PointerType::getUnqual(M.getContext()),
llvm::Type::getInt32Ty(M.getContext())};
llvm::FunctionType *VfprintfFuncType = llvm::FunctionType::get(
llvm::Type::getInt32Ty(M.getContext()), ArgTypes, false);
if (auto *F = M.getFunction(Name)) {
if (F->getFunctionType() != VfprintfFuncType) {
CGM.Error(SourceLocation(),
"Invalid type declaration for __llvm_omp_vfprintf");
return nullptr;
}
return F;
}
return llvm::Function::Create(
VfprintfFuncType, llvm::GlobalVariable::ExternalLinkage, Name, &M);
}
// Transforms a call to printf into a call to the NVPTX vprintf syscall (which
// isn't particularly special; it's invoked just like a regular function).
// vprintf takes two args: A format string, and a pointer to a buffer containing
@@ -170,6 +192,46 @@ RValue EmitDevicePrintfCallExpr(const CallExpr *E, CodeGenFunction *CGF,
}
return RValue::get(Builder.CreateCall(Decl, Vec));
}
RValue EmitDeviceFPrintfCallExpr(const CallExpr *E, CodeGenFunction *CGF,
llvm::Function *Decl, bool WithSizeArg) {
CodeGenModule &CGM = CGF->CGM;
CGBuilderTy &Builder = CGF->Builder;
assert(E->getBuiltinCallee() == Builtin::BIfprintf ||
E->getBuiltinCallee() == Builtin::BI__builtin_fprintf);
assert(E->getNumArgs() >= 2); // fprintf always has at least one arg.
// Uses the same format as nvptx for the argument packing, but also passes
// an i32 for the total size of the passed pointer
CallArgList Args;
CGF->EmitCallArgs(Args,
E->getDirectCallee()->getType()->getAs<FunctionProtoType>(),
E->arguments(), E->getDirectCallee(),
/* ParamsToSkip = */ 0);
// We don't know how to emit non-scalar varargs.
if (containsNonScalarVarargs(CGF, Args)) {
CGM.ErrorUnsupported(E, "non-scalar arg to printf");
return RValue::get(llvm::ConstantInt::get(CGF->IntTy, 0));
}
auto r = packArgsIntoNVPTXFormatBuffer(CGF, Args);
llvm::Value *BufferPtr = r.first;
llvm::SmallVector<llvm::Value *, 3> Vec = {
Args[0].getRValue(*CGF).getScalarVal(), BufferPtr};
if (WithSizeArg) {
// Passing > 32bit of data as a local alloca doesn't work for nvptx or
// amdgpu
llvm::Constant *Size =
llvm::ConstantInt::get(llvm::Type::getInt32Ty(CGM.getLLVMContext()),
static_cast<uint32_t>(r.second.getFixedValue()));
Vec.push_back(Size);
}
return RValue::get(Builder.CreateCall(Decl, Vec));
}
} // namespace
RValue CodeGenFunction::EmitNVPTXDevicePrintfCallExpr(const CallExpr *E) {
@@ -218,3 +280,11 @@ RValue CodeGenFunction::EmitOpenMPDevicePrintfCallExpr(const CallExpr *E) {
return EmitDevicePrintfCallExpr(E, this, GetOpenMPVprintfDeclaration(CGM),
true);
}
RValue CodeGenFunction::EmitOpenMPDeviceFPrintfCallExpr(const CallExpr *E) {
assert(getTarget().getTriple().isNVPTX() ||
getTarget().getTriple().isAMDGCN());
return EmitDeviceFPrintfCallExpr(E, this, GetOpenMPVfprintfDeclaration(CGM),
true);
}

View File

@@ -4469,6 +4469,7 @@ public:
RValue EmitNVPTXDevicePrintfCallExpr(const CallExpr *E);
RValue EmitAMDGPUDevicePrintfCallExpr(const CallExpr *E);
RValue EmitOpenMPDevicePrintfCallExpr(const CallExpr *E);
RValue EmitOpenMPDeviceFPrintfCallExpr(const CallExpr *E);
RValue EmitBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID,
const CallExpr *E, ReturnValueSlot ReturnValue);

View File

@@ -16,8 +16,11 @@ using namespace ompx;
#pragma omp begin declare target device_type(nohost)
struct FILE;
namespace impl {
int32_t omp_vprintf(const char *Format, void *Arguments, uint32_t);
int32_t omp_vfprintf(FILE* stream, const char *Format, void *Arguments, uint32_t Size);
}
#pragma omp begin declare variant match( \
@@ -28,6 +31,9 @@ namespace impl {
int32_t omp_vprintf(const char *Format, void *Arguments, uint32_t) {
return vprintf(Format, Arguments);
}
int32_t omp_vfprintf(FILE* stream, const char *Format, void *Arguments, uint32_t Size) {
return -1;
}
} // namespace impl
#pragma omp end declare variant
@@ -42,6 +48,9 @@ namespace impl {
int32_t omp_vprintf(const char *Format, void *Arguments, uint32_t Size) {
return rpc_fprintf(stdout, Format, Arguments, Size);
}
int32_t omp_vfprintf(FILE* stream, const char *Format, void *Arguments, uint32_t Size) {
return rpc_fprintf(stream, Format, Arguments, Size);
}
} // namespace impl
#else
// We do not have a vprintf implementation for AMD GPU so we use a stub.
@@ -49,6 +58,9 @@ namespace impl {
int32_t omp_vprintf(const char *Format, void *Arguments, uint32_t) {
return -1;
}
int32_t omp_vfprintf(FILE* stream, const char *Format, void *Arguments, uint32_t Size) {
return -1;
}
} // namespace impl
#endif
#pragma omp end declare variant
@@ -58,7 +70,7 @@ int StdInDummyVar;
int StdOutDummyVar;
int StdErrDummyVar;
struct FILE;
__attribute__((used, retain, weak, visibility("protected"))) FILE *stdin =
(FILE *)&StdInDummyVar;
__attribute__((used, retain, weak, visibility("protected"))) FILE *stdout =
@@ -386,6 +398,10 @@ int32_t __llvm_omp_vprintf(const char *Format, void *Arguments, uint32_t Size) {
return impl::omp_vprintf(Format, Arguments, Size);
}
int32_t __llvm_omp_vfprintf(FILE *stream, const char *Format, void *Arguments, uint32_t Size) {
return impl::omp_vfprintf(stream, Format, Arguments, Size);
}
// -----------------------------------------------------------------------------
#ifndef ULONG_MAX