Compare commits

...

1 Commits

Author SHA1 Message Date
Shilei Tian
af01c4c979 [LLVM][WIP] Introduced heterogenous IR module 2021-03-04 22:43:56 -05:00
22 changed files with 1184 additions and 231 deletions

View File

@@ -794,6 +794,9 @@ endif()
# compatibility.
set(LLVM_ENABLE_NEW_PASS_MANAGER ${ENABLE_EXPERIMENTAL_NEW_PASS_MANAGER})
# By default a heterogenous module can support 32 targets
set(LLVM_MODULE_NUM_TARGETS 32)
# Configure the three LLVM configuration header files.
configure_file(
${LLVM_MAIN_INCLUDE_DIR}/llvm/Config/config.h.cmake

View File

@@ -117,6 +117,12 @@ enum ModuleCodes {
// IFUNC: [ifunc value type, addrspace, resolver val#, linkage, visibility]
MODULE_CODE_IFUNC = 18,
// NUM_TARGETS: [num of targets]
MODULE_CODE_NUM_TARGETS = 19,
// TARGET_ID: [target id]
MODULE_CODE_TARGET_ID = 20,
};
/// PARAMATTR blocks have code for defining a parameter attribute set.

View File

@@ -94,4 +94,7 @@
/* Define to 1 to enable the experimental new pass manager by default */
#cmakedefine01 LLVM_ENABLE_NEW_PASS_MANAGER
/* Define as the number of targets that a heterogenous module can support */
#cmakedefine LLVM_MODULE_NUM_TARGETS ${LLVM_MODULE_NUM_TARGETS}
#endif

View File

@@ -65,7 +65,6 @@ public:
/// unknown values that are passed to the callee.
using ParameterEncodingTy = SmallVector<int, 0>;
ParameterEncodingTy ParameterEncoding;
};
private:
@@ -128,6 +127,12 @@ public:
return !CI.ParameterEncoding.empty();
}
/// Return true if this ACS represents a heterogenous callback call.
bool isHeterogenousCallbackCall() const {
return isCallbackCall() &&
CI.ParameterEncoding[0] != CI.ParameterEncoding.back();
}
/// Return true if @p UI is the use that defines the callee of this ACS.
bool isCallee(Value::const_user_iterator UI) const {
return isCallee(&UI.getUse());
@@ -141,6 +146,10 @@ public:
assert(!CI.ParameterEncoding.empty() &&
"Callback without parameter encoding!");
if (isHeterogenousCallbackCall())
return static_cast<int>(CB->getDataOperandNo(U)) ==
CI.ParameterEncoding.back();
// If the use is actually in a constant cast expression which itself
// has only one use, we look through the constant cast expression.
if (auto *CE = dyn_cast<ConstantExpr>(U->getUser()))
@@ -154,8 +163,9 @@ public:
unsigned getNumArgOperands() const {
if (isDirectCall())
return CB->getNumArgOperands();
// Subtract 1 for the callee encoding.
return CI.ParameterEncoding.size() - 1;
// Subtract 1 for the callee encoding, and another 1 for the heterogenous
// callee encoding at the end.
return CI.ParameterEncoding.size() - 2;
}
/// Return the operand index of the underlying instruction associated with @p
@@ -194,7 +204,7 @@ public:
int getCallArgOperandNoForCallee() const {
assert(isCallbackCall());
assert(CI.ParameterEncoding.size() && CI.ParameterEncoding[0] >= 0);
return CI.ParameterEncoding[0];
return CI.ParameterEncoding.back();
}
/// Return the use of the callee value in the underlying instruction. Only
@@ -210,7 +220,7 @@ public:
Value *getCalledOperand() const {
if (isDirectCall())
return CB->getCalledOperand();
return CB->getArgOperand(getCallArgOperandNoForCallee());
return CB->getOperand(getCallArgOperandNoForCallee());
}
/// Return the function being called if this is a direct call, otherwise

View File

@@ -265,6 +265,9 @@ public:
bool isIllegalInteger(uint64_t Width) const { return !isLegalInteger(Width); }
/// Return true if this DataLayout is compatible with \p Other.
bool isCompatibleWith(const DataLayout &Other) const;
/// Returns true if the given alignment exceeds the natural stack alignment.
bool exceedsNaturalStackAlignment(Align Alignment) const {
return StackNaturalAlign && (Alignment > *StackNaturalAlign);

View File

@@ -1103,6 +1103,15 @@ struct OperandBundleUse {
return getTagID() == LLVMContext::OB_cfguardtarget;
}
/// Return true if the use \p U is in the operand bundle.
bool isUseInBundle(const Use * U) const {
for (const Use &UseInThisBundle : Inputs)
if (*U == UseInThisBundle)
return true;
return false;
}
private:
/// Pointer to an entry in LLVMContextImpl::getOrInsertBundleTag.
StringMapEntry<uint32_t> *Tag;

View File

@@ -65,8 +65,8 @@ class VersionTuple;
/// variable is destroyed, it should have no entries in the GlobalValueRefMap.
/// The main container class for the LLVM Intermediate Representation.
class Module {
/// @name Types And Enumerations
/// @{
/// @name Types And Enumerations
/// @{
public:
/// The type for the list of global variables.
using GlobalListType = SymbolTableList<GlobalVariable>;
@@ -171,48 +171,66 @@ public:
: Behavior(B), Key(K), Val(V) {}
};
/// @}
/// @name Member Variables
/// @{
/// @}
/// @name Member Variables
/// @{
private:
LLVMContext &Context; ///< The LLVMContext from which types and
///< constants are allocated.
GlobalListType GlobalList; ///< The Global Variables in the module
FunctionListType FunctionList; ///< The Functions in the module
AliasListType AliasList; ///< The Aliases in the module
IFuncListType IFuncList; ///< The IFuncs in the module
NamedMDListType NamedMDList; ///< The named metadata in the module
std::string GlobalScopeAsm; ///< Inline Asm at global scope.
LLVMContext &Context; ///< The LLVMContext from which types and
///< constants are allocated.
GlobalListType GlobalList; ///< The Global Variables in the module
FunctionListType FunctionList; ///< The Functions in the module
AliasListType AliasList; ///< The Aliases in the module
IFuncListType IFuncList; ///< The IFuncs in the module
NamedMDListType NamedMDList; ///< The named metadata in the module
std::string GlobalScopeAsm; ///< Inline Asm at global scope.
std::unique_ptr<ValueSymbolTable> ValSymTab; ///< Symbol table for values
ComdatSymTabType ComdatSymTab; ///< Symbol table for COMDATs
ComdatSymTabType ComdatSymTab; ///< Symbol table for COMDATs
std::unique_ptr<MemoryBuffer>
OwnedMemoryBuffer; ///< Memory buffer directly owned by this
///< module, for legacy clients only.
OwnedMemoryBuffer; ///< Memory buffer directly owned by this
///< module, for legacy clients only.
std::unique_ptr<GVMaterializer>
Materializer; ///< Used to materialize GlobalValues
std::string ModuleID; ///< Human readable identifier for the module
std::string SourceFileName; ///< Original source file name for module,
///< recorded in bitcode.
std::string TargetTriple; ///< Platform target triple Module compiled on
///< Format: (arch)(sub)-(vendor)-(sys0-(abi)
NamedMDSymTabType NamedMDSymTab; ///< NamedMDNode names.
DataLayout DL; ///< DataLayout associated with the module
Materializer; ///< Used to materialize GlobalValues
std::string ModuleID; ///< Human readable identifier for the module
std::string SourceFileName; ///< Original source file name for module,
///< recorded in bitcode.
std::string TargetTriple; ///< Platform target triple Module compiled on
///< Format: (arch)(sub)-(vendor)-(sys0-(abi)
NamedMDSymTabType NamedMDSymTab; ///< NamedMDNode names.
DataLayout DL; ///< DataLayout associated with the module
/// @}
/// @name Members for heterogenous module
/// @{
bool IsHeterogenousModule; ///< Whether this module is heterogenous
SmallVector<DataLayout, LLVM_MODULE_NUM_TARGETS>
DLs; ///< DataLayout associated with the heterogenous module
SmallVector<std::string, LLVM_MODULE_NUM_TARGETS>
TargetTriples; ///< Platform target triple the heterogenous module
SmallVector<std::string, LLVM_MODULE_NUM_TARGETS>
GlobalScopeAsms; ///< Inline Asm at global scope.
unsigned
ActiveTarget; ///< The active target id if heterogenous module is enabled
unsigned NumTargets;
friend class Constant;
/// @}
/// @name Constructors
/// @{
/// @}
/// @name Constructors
/// @{
public:
/// The Module constructor. Note that there is no default constructor. You
/// must provide a name for the module upon construction.
explicit Module(StringRef ModuleID, LLVMContext& C);
explicit Module(StringRef ModuleID, LLVMContext &C);
/// The module destructor. This will dropAllReferences.
~Module();
/// @}
/// @name Module Level Accessors
/// @{
/// @}
/// @name Module Level Accessors
/// @{
/// Return if the module is heterogenous module
bool isHeterogenousModule() const { return IsHeterogenousModule; }
/// Get the module identifier which is, essentially, the name of the module.
/// @returns the module identifier as a string
@@ -238,15 +256,64 @@ public:
/// Get the data layout string for the module's target platform. This is
/// equivalent to getDataLayout()->getStringRepresentation().
const std::string &getDataLayoutStr() const {
// If it is a heterogenous module, we need to return the one of active
// target
if (IsHeterogenousModule)
return getDataLayoutStr(ActiveTarget);
// Otherwise, the original one will be returned
return DL.getStringRepresentation();
}
/// Get the data layout string for the TargetId-th module's target platform.
const std::string &getDataLayoutStr(unsigned TargetId) const {
assert(TargetId < LLVM_MODULE_NUM_TARGETS &&
"TargetId is expected in [0, 31]");
if (TargetId == 0)
return DL.getStringRepresentation();
assert(IsHeterogenousModule &&
"IsHeterogenousModule should be true if trying to get the data "
"layout string of non-first target");
return DLs[TargetId - 1].getStringRepresentation();
}
/// Get the data layout for the module's target platform.
const DataLayout &getDataLayout() const;
/// Get the data layout for the TargetId-th module's target platform.
const DataLayout &getDataLayout(unsigned TargetId) const;
/// Get the target triple which is a string describing the target host.
/// @returns a string containing the target triple.
const std::string &getTargetTriple() const { return TargetTriple; }
const std::string &getTargetTriple() const {
// If it is a heterogenous module, we need to return the one of active
// target
if (IsHeterogenousModule)
return getTargetTriple(ActiveTarget);
// Otherwise, the original one will be returned
return TargetTriple;
}
/// Get the TargetId-th target triple which is a string describing the target
/// host.
/// @returns a string containing the target triple.
const std::string &getTargetTriple(unsigned TargetId) const {
assert(TargetId < LLVM_MODULE_NUM_TARGETS &&
"TargetId is expected in [0, 31]");
if (TargetId == 0)
return TargetTriple;
assert(IsHeterogenousModule &&
"IsHeterogenousModule should be true if trying to get the target "
"triple of non-first target");
return TargetTriples[TargetId - 1];
}
/// Get the global data context.
/// @returns LLVMContext - a container for LLVM's global information
@@ -254,7 +321,31 @@ public:
/// Get any module-scope inline assembly blocks.
/// @returns a string containing the module-scope inline assembly blocks.
const std::string &getModuleInlineAsm() const { return GlobalScopeAsm; }
const std::string &getModuleInlineAsm() const {
// If it is a heterogenous module, we need to return the one of active
// target
if (IsHeterogenousModule)
return getModuleInlineAsm(ActiveTarget);
// Otherwise, the original one will be returned
return GlobalScopeAsm;
}
/// Get any module-scope inline assembly blocks from TargetId-th module.
/// @returns a string containing the module-scope inline assembly blocks.
const std::string &getModuleInlineAsm(unsigned TargetId) const {
assert(TargetId < LLVM_MODULE_NUM_TARGETS &&
"TargetId is expected in [0, 31]");
if (TargetId == 0)
return GlobalScopeAsm;
assert(IsHeterogenousModule &&
"IsHeterogenousModule should be true if trying to get the inline "
"assembly blocks of non-first target");
return GlobalScopeAsms[TargetId - 1];
}
/// Get a RandomNumberGenerator salted for use with this module. The
/// RNG can be seeded via -rng-seed=<uint64> and is salted with the
@@ -278,6 +369,9 @@ public:
/// @name Module Level Mutators
/// @{
/// Mark the module as heterogenous
void markHeterogenous() { IsHeterogenousModule = true; }
/// Set the module identifier.
void setModuleIdentifier(StringRef ID) { ModuleID = std::string(ID); }
@@ -287,34 +381,86 @@ public:
/// Set the data layout
void setDataLayout(StringRef Desc);
void setDataLayout(const DataLayout &Other);
void setDataLayout(StringRef Desc, unsigned TargetId);
void setDataLayout(const DataLayout &Other, unsigned TargetId);
/// Set the target triple.
// Target triple should only be set when building the module, so we don't need
// to use ActiveTarget here.
void setTargetTriple(StringRef T) { TargetTriple = std::string(T); }
void setTargetTriple(StringRef T, unsigned TargetId) {
assert(TargetId < LLVM_MODULE_NUM_TARGETS &&
"TargetId is expected in [0, 31]");
if (TargetId == 0)
return setTargetTriple(T);
assert(IsHeterogenousModule &&
"IsHeterogenousModule should be true if trying to set the target "
"triple of non-first target");
TargetTriples[TargetId - 1] = std::string(T);
}
/// Set the module-scope inline assembly blocks.
/// A trailing newline is added if the input doesn't have one.
void setModuleInlineAsm(StringRef Asm) {
if (IsHeterogenousModule) {
setModuleInlineAsm(Asm, ActiveTarget);
return;
}
GlobalScopeAsm = std::string(Asm);
if (!GlobalScopeAsm.empty() && GlobalScopeAsm.back() != '\n')
GlobalScopeAsm += '\n';
}
/// Set the module-scope inline assembly blocks of TargetId-th target.
/// A trailing newline is added if the input doesn't have one.
void setModuleInlineAsm(StringRef Asm, unsigned TargetId) {
assert(TargetId < LLVM_MODULE_NUM_TARGETS &&
"TargetId is expected in [0, 31]");
std::string &AsmRef =
(TargetId == 0) ? GlobalScopeAsm : GlobalScopeAsms[TargetId - 1];
AsmRef = std::string(Asm);
if (!AsmRef.empty() && AsmRef.back() != '\n')
AsmRef += '\n';
}
/// Append to the module-scope inline assembly blocks.
/// A trailing newline is added if the input doesn't have one.
void appendModuleInlineAsm(StringRef Asm) {
if (IsHeterogenousModule)
return appendModuleInlineAsm(Asm, ActiveTarget);
GlobalScopeAsm += Asm;
if (!GlobalScopeAsm.empty() && GlobalScopeAsm.back() != '\n')
GlobalScopeAsm += '\n';
}
/// @}
/// @name Generic Value Accessors
/// @{
/// Append to the module-scope inline assembly blocks of TargetId-th target.
/// A trailing newline is added if the input doesn't have one.
void appendModuleInlineAsm(StringRef Asm, unsigned TargetId) {
std::string *AsmRef = &GlobalScopeAsm;
if (TargetId)
AsmRef = &GlobalScopeAsms[TargetId - 1];
*AsmRef += Asm;
if (!AsmRef->empty() && AsmRef->back() != '\n')
*AsmRef += '\n';
}
/// @}
/// @name Generic Value Accessors
/// @{
/// Return the global value in the module with the specified name, of
/// arbitrary type. This method returns null if a global with the specified
/// name is not found.
GlobalValue *getNamedValue(StringRef Name) const;
GlobalValue *getNamedValue(StringRef Name, unsigned TargetId) const;
/// Return a unique non-zero ID for the specified metadata kind. This ID is
/// uniqued across modules in the current LLVMContext.
@@ -331,9 +477,9 @@ public:
std::vector<StructType *> getIdentifiedStructTypes() const;
/// @}
/// @name Function Accessors
/// @{
/// @}
/// @name Function Accessors
/// @{
/// Look up the specified function in the module symbol table. Four
/// possibilities:
@@ -344,12 +490,18 @@ public:
/// function with a constantexpr cast to the right prototype.
///
/// In all cases, the returned value is a FunctionCallee wrapper around the
/// 'FunctionType *T' passed in, as well as a 'Value*' either of the Function or
/// the bitcast to the function.
/// 'FunctionType *T' passed in, as well as a 'Value*' either of the Function
/// or the bitcast to the function.
FunctionCallee getOrInsertFunction(StringRef Name, FunctionType *T,
AttributeList AttributeList);
FunctionCallee getOrInsertFunction(StringRef Name, FunctionType *T,
AttributeList AttributeList,
unsigned TargetId);
FunctionCallee getOrInsertFunction(StringRef Name, FunctionType *T);
FunctionCallee getOrInsertFunction(StringRef Name, FunctionType *T,
unsigned TargetId);
/// Look up the specified function in the module symbol table. If it does not
/// exist, add a prototype for the function and return it. This function
@@ -361,12 +513,20 @@ public:
FunctionCallee getOrInsertFunction(StringRef Name,
AttributeList AttributeList, Type *RetTy,
ArgsTy... Args) {
SmallVector<Type*, sizeof...(ArgsTy)> ArgTys{Args...};
return getOrInsertFunction(Name,
FunctionType::get(RetTy, ArgTys, false),
SmallVector<Type *, sizeof...(ArgsTy)> ArgTys{Args...};
return getOrInsertFunction(Name, FunctionType::get(RetTy, ArgTys, false),
AttributeList);
}
template <typename... ArgsTy>
FunctionCallee getOrInsertFunction(StringRef Name,
AttributeList AttributeList, Type *RetTy,
unsigned TargetId, ArgsTy... Args) {
SmallVector<Type *, sizeof...(ArgsTy)> ArgTys{Args...};
return getOrInsertFunction(Name, FunctionType::get(RetTy, ArgTys, false),
AttributeList, TargetId);
}
/// Same as above, but without the attributes.
template <typename... ArgsTy>
FunctionCallee getOrInsertFunction(StringRef Name, Type *RetTy,
@@ -374,6 +534,12 @@ public:
return getOrInsertFunction(Name, AttributeList{}, RetTy, Args...);
}
template <typename... ArgsTy>
FunctionCallee getOrInsertFunction(StringRef Name, Type *RetTy,
unsigned TargetId, ArgsTy... Args) {
return getOrInsertFunction(Name, AttributeList{}, RetTy, TargetId, Args...);
}
// Avoid an incorrect ordering that'd otherwise compile incorrectly.
template <typename... ArgsTy>
FunctionCallee
@@ -384,9 +550,11 @@ public:
/// exist, return null.
Function *getFunction(StringRef Name) const;
/// @}
/// @name Global Variable Accessors
/// @{
Function *getFunction(StringRef Name, unsigned TargetId) const;
/// @}
/// @name Global Variable Accessors
/// @{
/// Look up the specified global variable in the module symbol table. If it
/// does not exist, return null. If AllowInternal is set to true, this
@@ -396,23 +564,44 @@ public:
return getGlobalVariable(Name, false);
}
GlobalVariable *getGlobalVariable(StringRef Name, unsigned TargetId) const {
return getGlobalVariable(Name, false, TargetId);
}
GlobalVariable *getGlobalVariable(StringRef Name, bool AllowInternal) const;
GlobalVariable *getGlobalVariable(StringRef Name, bool AllowInternal,
unsigned TargetId) const;
GlobalVariable *getGlobalVariable(StringRef Name,
bool AllowInternal = false) {
return static_cast<const Module *>(this)->getGlobalVariable(Name,
AllowInternal);
}
GlobalVariable *getGlobalVariable(StringRef Name, unsigned TargetId,
bool AllowInternal = false) {
return static_cast<const Module *>(this)->getGlobalVariable(
Name, AllowInternal, TargetId);
}
/// Return the global variable in the module with the specified name, of
/// arbitrary type. This method returns null if a global with the specified
/// name is not found.
const GlobalVariable *getNamedGlobal(StringRef Name) const {
return getGlobalVariable(Name, true);
}
const GlobalVariable *getNamedGlobal(StringRef Name,
unsigned TargetId) const {
return getGlobalVariable(Name, TargetId, true);
}
GlobalVariable *getNamedGlobal(StringRef Name) {
return const_cast<GlobalVariable *>(
static_cast<const Module *>(this)->getNamedGlobal(Name));
static_cast<const Module *>(this)->getNamedGlobal(Name));
}
GlobalVariable *getNamedGlobal(StringRef Name, unsigned TargetId) {
return const_cast<GlobalVariable *>(
static_cast<const Module *>(this)->getNamedGlobal(Name, TargetId));
}
/// Look up the specified global in the module symbol table.
@@ -422,32 +611,39 @@ public:
Constant *
getOrInsertGlobal(StringRef Name, Type *Ty,
function_ref<GlobalVariable *()> CreateGlobalCallback);
Constant *
getOrInsertGlobal(StringRef Name, Type *Ty,
function_ref<GlobalVariable *()> CreateGlobalCallback,
unsigned TargetId);
/// Look up the specified global in the module symbol table. If required, this
/// overload constructs the global variable using its constructor's defaults.
Constant *getOrInsertGlobal(StringRef Name, Type *Ty);
Constant *getOrInsertGlobal(StringRef Name, Type *Ty, unsigned TargetId);
/// @}
/// @name Global Alias Accessors
/// @{
/// @}
/// @name Global Alias Accessors
/// @{
/// Return the global alias in the module with the specified name, of
/// arbitrary type. This method returns null if a global with the specified
/// name is not found.
GlobalAlias *getNamedAlias(StringRef Name) const;
GlobalAlias *getNamedAlias(StringRef Name, unsigned TargetId) const;
/// @}
/// @name Global IFunc Accessors
/// @{
/// @}
/// @name Global IFunc Accessors
/// @{
/// Return the global ifunc in the module with the specified name, of
/// arbitrary type. This method returns null if a global with the specified
/// name is not found.
GlobalIFunc *getNamedIFunc(StringRef Name) const;
GlobalIFunc *getNamedIFunc(StringRef Name, unsigned TargetId) const;
/// @}
/// @name Named Metadata Accessors
/// @{
/// @}
/// @name Named Metadata Accessors
/// @{
/// Return the first NamedMDNode in the module with the specified name. This
/// method returns null if a NamedMDNode with the specified name is not found.
@@ -461,17 +657,18 @@ public:
/// Remove the given NamedMDNode from this module and delete it.
void eraseNamedMetadata(NamedMDNode *NMD);
/// @}
/// @name Comdat Accessors
/// @{
/// @}
/// @name Comdat Accessors
/// @{
/// Return the Comdat in the module with the specified name. It is created
/// if it didn't already exist.
Comdat *getOrInsertComdat(StringRef Name);
Comdat *getOrInsertComdat(StringRef Name, unsigned TargetId);
/// @}
/// @name Module Flags Accessors
/// @{
/// @}
/// @name Module Flags Accessors
/// @{
/// Returns the module flags in the provided vector.
void getModuleFlagsMetadata(SmallVectorImpl<ModuleFlagEntry> &Flags) const;
@@ -523,74 +720,74 @@ public:
llvm::Error materializeMetadata();
/// @}
/// @name Direct access to the globals list, functions list, and symbol table
/// @{
/// @}
/// @name Direct access to the globals list, functions list, and symbol table
/// @{
/// Get the Module's list of global variables (constant).
const GlobalListType &getGlobalList() const { return GlobalList; }
const GlobalListType &getGlobalList() const { return GlobalList; }
/// Get the Module's list of global variables.
GlobalListType &getGlobalList() { return GlobalList; }
GlobalListType &getGlobalList() { return GlobalList; }
static GlobalListType Module::*getSublistAccess(GlobalVariable*) {
static GlobalListType Module::*getSublistAccess(GlobalVariable *) {
return &Module::GlobalList;
}
/// Get the Module's list of functions (constant).
const FunctionListType &getFunctionList() const { return FunctionList; }
const FunctionListType &getFunctionList() const { return FunctionList; }
/// Get the Module's list of functions.
FunctionListType &getFunctionList() { return FunctionList; }
static FunctionListType Module::*getSublistAccess(Function*) {
FunctionListType &getFunctionList() { return FunctionList; }
static FunctionListType Module::*getSublistAccess(Function *) {
return &Module::FunctionList;
}
/// Get the Module's list of aliases (constant).
const AliasListType &getAliasList() const { return AliasList; }
const AliasListType &getAliasList() const { return AliasList; }
/// Get the Module's list of aliases.
AliasListType &getAliasList() { return AliasList; }
AliasListType &getAliasList() { return AliasList; }
static AliasListType Module::*getSublistAccess(GlobalAlias*) {
static AliasListType Module::*getSublistAccess(GlobalAlias *) {
return &Module::AliasList;
}
/// Get the Module's list of ifuncs (constant).
const IFuncListType &getIFuncList() const { return IFuncList; }
const IFuncListType &getIFuncList() const { return IFuncList; }
/// Get the Module's list of ifuncs.
IFuncListType &getIFuncList() { return IFuncList; }
IFuncListType &getIFuncList() { return IFuncList; }
static IFuncListType Module::*getSublistAccess(GlobalIFunc*) {
static IFuncListType Module::*getSublistAccess(GlobalIFunc *) {
return &Module::IFuncList;
}
/// Get the Module's list of named metadata (constant).
const NamedMDListType &getNamedMDList() const { return NamedMDList; }
const NamedMDListType &getNamedMDList() const { return NamedMDList; }
/// Get the Module's list of named metadata.
NamedMDListType &getNamedMDList() { return NamedMDList; }
NamedMDListType &getNamedMDList() { return NamedMDList; }
static NamedMDListType Module::*getSublistAccess(NamedMDNode*) {
static NamedMDListType Module::*getSublistAccess(NamedMDNode *) {
return &Module::NamedMDList;
}
/// Get the symbol table of global variable and function identifiers
const ValueSymbolTable &getValueSymbolTable() const { return *ValSymTab; }
/// Get the Module's symbol table of global variable and function identifiers.
ValueSymbolTable &getValueSymbolTable() { return *ValSymTab; }
ValueSymbolTable &getValueSymbolTable() { return *ValSymTab; }
/// Get the Module's symbol table for COMDATs (constant).
const ComdatSymTabType &getComdatSymbolTable() const { return ComdatSymTab; }
/// Get the Module's symbol table for COMDATs.
ComdatSymTabType &getComdatSymbolTable() { return ComdatSymTab; }
/// @}
/// @name Global Variable Iteration
/// @{
/// @}
/// @name Global Variable Iteration
/// @{
global_iterator global_begin() { return GlobalList.begin(); }
global_iterator global_begin() { return GlobalList.begin(); }
const_global_iterator global_begin() const { return GlobalList.begin(); }
global_iterator global_end () { return GlobalList.end(); }
const_global_iterator global_end () const { return GlobalList.end(); }
size_t global_size () const { return GlobalList.size(); }
bool global_empty() const { return GlobalList.empty(); }
global_iterator global_end() { return GlobalList.end(); }
const_global_iterator global_end() const { return GlobalList.end(); }
size_t global_size() const { return GlobalList.size(); }
bool global_empty() const { return GlobalList.empty(); }
iterator_range<global_iterator> globals() {
return make_range(global_begin(), global_end());
@@ -599,38 +796,36 @@ public:
return make_range(global_begin(), global_end());
}
/// @}
/// @name Function Iteration
/// @{
/// @}
/// @name Function Iteration
/// @{
iterator begin() { return FunctionList.begin(); }
const_iterator begin() const { return FunctionList.begin(); }
iterator end () { return FunctionList.end(); }
const_iterator end () const { return FunctionList.end(); }
reverse_iterator rbegin() { return FunctionList.rbegin(); }
const_reverse_iterator rbegin() const{ return FunctionList.rbegin(); }
reverse_iterator rend() { return FunctionList.rend(); }
const_reverse_iterator rend() const { return FunctionList.rend(); }
size_t size() const { return FunctionList.size(); }
bool empty() const { return FunctionList.empty(); }
iterator begin() { return FunctionList.begin(); }
const_iterator begin() const { return FunctionList.begin(); }
iterator end() { return FunctionList.end(); }
const_iterator end() const { return FunctionList.end(); }
reverse_iterator rbegin() { return FunctionList.rbegin(); }
const_reverse_iterator rbegin() const { return FunctionList.rbegin(); }
reverse_iterator rend() { return FunctionList.rend(); }
const_reverse_iterator rend() const { return FunctionList.rend(); }
size_t size() const { return FunctionList.size(); }
bool empty() const { return FunctionList.empty(); }
iterator_range<iterator> functions() {
return make_range(begin(), end());
}
iterator_range<iterator> functions() { return make_range(begin(), end()); }
iterator_range<const_iterator> functions() const {
return make_range(begin(), end());
}
/// @}
/// @name Alias Iteration
/// @{
/// @}
/// @name Alias Iteration
/// @{
alias_iterator alias_begin() { return AliasList.begin(); }
const_alias_iterator alias_begin() const { return AliasList.begin(); }
alias_iterator alias_end () { return AliasList.end(); }
const_alias_iterator alias_end () const { return AliasList.end(); }
size_t alias_size () const { return AliasList.size(); }
bool alias_empty() const { return AliasList.empty(); }
alias_iterator alias_begin() { return AliasList.begin(); }
const_alias_iterator alias_begin() const { return AliasList.begin(); }
alias_iterator alias_end() { return AliasList.end(); }
const_alias_iterator alias_end() const { return AliasList.end(); }
size_t alias_size() const { return AliasList.size(); }
bool alias_empty() const { return AliasList.empty(); }
iterator_range<alias_iterator> aliases() {
return make_range(alias_begin(), alias_end());
@@ -639,16 +834,16 @@ public:
return make_range(alias_begin(), alias_end());
}
/// @}
/// @name IFunc Iteration
/// @{
/// @}
/// @name IFunc Iteration
/// @{
ifunc_iterator ifunc_begin() { return IFuncList.begin(); }
const_ifunc_iterator ifunc_begin() const { return IFuncList.begin(); }
ifunc_iterator ifunc_end () { return IFuncList.end(); }
const_ifunc_iterator ifunc_end () const { return IFuncList.end(); }
size_t ifunc_size () const { return IFuncList.size(); }
bool ifunc_empty() const { return IFuncList.empty(); }
ifunc_iterator ifunc_begin() { return IFuncList.begin(); }
const_ifunc_iterator ifunc_begin() const { return IFuncList.begin(); }
ifunc_iterator ifunc_end() { return IFuncList.end(); }
const_ifunc_iterator ifunc_end() const { return IFuncList.end(); }
size_t ifunc_size() const { return IFuncList.size(); }
bool ifunc_empty() const { return IFuncList.empty(); }
iterator_range<ifunc_iterator> ifuncs() {
return make_range(ifunc_begin(), ifunc_end());
@@ -694,7 +889,7 @@ public:
return NamedMDList.end();
}
size_t named_metadata_size() const { return NamedMDList.size(); }
size_t named_metadata_size() const { return NamedMDList.size(); }
bool named_metadata_empty() const { return NamedMDList.empty(); }
iterator_range<named_metadata_iterator> named_metadata() {
@@ -761,7 +956,7 @@ public:
debug_compile_units_iterator(CUs, 0),
debug_compile_units_iterator(CUs, CUs ? CUs->getNumOperands() : 0));
}
/// @}
/// @}
/// Destroy ConstantArrays in LLVMContext if they are not used.
/// ConstantArrays constructed during linking can cause quadratic memory
@@ -772,8 +967,8 @@ public:
/// be called where all uses of the LLVMContext are understood.
void dropTriviallyDeadConstantArrays();
/// @name Utility functions for printing and dumping Module objects
/// @{
/// @name Utility functions for printing and dumping Module objects
/// @{
/// Print the module to an output stream with an optional
/// AssemblyAnnotationWriter. If \c ShouldPreserveUseListOrder, then include
@@ -794,9 +989,9 @@ public:
/// that has "dropped all references", except operator delete.
void dropAllReferences();
/// @}
/// @name Utility functions for querying Debug information.
/// @{
/// @}
/// @name Utility functions for querying Debug information.
/// @{
/// Returns the Number of Register ParametersDwarf Version by checking
/// module flags.
@@ -812,27 +1007,27 @@ public:
/// Returns zero if not present in module.
unsigned getCodeViewFlag() const;
/// @}
/// @name Utility functions for querying and setting PIC level
/// @{
/// @}
/// @name Utility functions for querying and setting PIC level
/// @{
/// Returns the PIC level (small or large model)
PICLevel::Level getPICLevel() const;
/// Set the PIC level (small or large model)
void setPICLevel(PICLevel::Level PL);
/// @}
/// @}
/// @}
/// @name Utility functions for querying and setting PIE level
/// @{
/// @}
/// @name Utility functions for querying and setting PIE level
/// @{
/// Returns the PIE level (small or large model)
PIELevel::Level getPIELevel() const;
/// Set the PIE level (small or large model)
void setPIELevel(PIELevel::Level PL);
/// @}
/// @}
/// @}
/// @name Utility function for querying and setting code model
@@ -886,6 +1081,26 @@ public:
/// Set the partial sample profile ratio in the profile summary module flag,
/// if applicable.
void setPartialSampleProfileRatio(const ModuleSummaryIndex &Index);
/// @}
/// @name Utility functions for heterogenous module
/// @{
/// Get the number of targets this module contains
unsigned getNumTargets() const { return NumTargets; }
/// Set the number of targets this module can cantain
void setNumTargets(unsigned Num) {
assert(IsHeterogenousModule &&
"This function can only be called when the module is heterogenous");
assert(Num <= LLVM_MODULE_NUM_TARGETS && "");
NumTargets = Num;
}
/// Set the active target id
void setActiveTarget(unsigned TargetId) { ActiveTarget = TargetId; }
/// Get the active target id
unsigned getActiveTarget() const { return ActiveTarget; }
};
/// Given "llvm.used" or "llvm.compiler.used" as a global name, collect the
@@ -908,9 +1123,50 @@ DEFINE_SIMPLE_CONVERSION_FUNCTIONS(Module, LLVMModuleRef)
* Module.
*/
inline Module *unwrap(LLVMModuleProviderRef MP) {
return reinterpret_cast<Module*>(MP);
return reinterpret_cast<Module *>(MP);
}
namespace heterogenous {
/// Check if a given global name is a mangled name for heterogenous module.
inline bool isMangledName(StringRef Name) {
return Name.startswith("[target]");
}
/// Check if a given global name is an internal function, e.g. llvm.*.
inline bool isInternalName(StringRef Name) {
return Name.startswith("llvm.") || Name.startswith("nvvm.");
}
/// Return true if the Value needs to be mangled (renamed) when merged to a
/// heterogenous module.
inline bool isMangleNeeded(const llvm::Value *V) {
return !(isInternalName(V->getName()));
}
/// Mangle a name to a heterogenous name.
inline std::string mangleName(StringRef Name, unsigned TargetId) {
return "[target][" + std::to_string(TargetId) + "]" + std::string(Name);
}
/// Demangle a name, and return a pair of target id and the original name.
inline std::pair<unsigned, StringRef> demangleName(StringRef Name) {
assert(isMangledName(Name) && "Name should be mangled");
assert(Name[8] == '[' && "Unrecognized mangling");
unsigned I = 9;
unsigned Id = 0;
while (Name[I] != ']') {
assert(isdigit(Name[I]) && "Unrecognized mangling");
Id = Id * 10 + Name[I] - '0';
++I;
}
assert(Id < LLVM_MODULE_NUM_TARGETS && "Id is larger than maximum value");
return std::make_pair(Id, Name.substr(I + 1));
}
} // namespace heterogenous
} // end namespace llvm
#endif // LLVM_IR_MODULE_H

View File

@@ -0,0 +1,27 @@
//===-- HeterogenousModuleUtils.h -------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines some utility functions for heterognous module.
//
//===----------------------------------------------------------------------===//
#ifndef LLVM_TRANSFORMS_UTILS_HETEROGENOUSMODULEUTILS_H
#define LLVM_TRANSFORMS_UTILS_HETEROGENOUSMODULEUTILS_H
#include "llvm/ADT/StringRef.h"
namespace llvm {
class Module;
namespace heterogenous {
/// Rename symbols in a module before getting merged into a heterogenous module.
void renameModuleSymbols(Module &SrcM, unsigned TargetId);
} // namespace heterogenous
} // namespace llvm
#endif // LLVM_TRANSFORMS_UTILS_HETEROGENOUSMODULEUTILS_H

View File

@@ -443,6 +443,31 @@ bool LLParser::parseTargetDefinition() {
return true;
M->setDataLayout(Str);
return false;
case lltok::APSInt:
assert(Lex.getKind() == lltok::APSInt);
unsigned TargetId = Lex.getAPSIntVal().getZExtValue();
M->markHeterogenous();
// TODO: Use a more robust way to set the number of targets
M->setNumTargets(TargetId + 1);
switch (Lex.Lex()) {
default:
return tokError("unknown target property");
case lltok::kw_triple:
Lex.Lex();
if (parseToken(lltok::equal, "expected '=' after target triple") ||
parseStringConstant(Str))
return true;
M->setTargetTriple(Str, TargetId);
return false;
case lltok::kw_datalayout:
Lex.Lex();
if (parseToken(lltok::equal, "expected '=' after target datalayout") ||
parseStringConstant(Str))
return true;
M->setDataLayout(Str, TargetId);
return false;
}
llvm_unreachable("should never reach here");
}
}

View File

@@ -3506,6 +3506,8 @@ Error BitcodeReader::parseModule(uint64_t ResumeBit,
SmallVector<uint64_t, 64> Record;
unsigned CurrentTargetId = 0;
// Parts of bitcode parsing depend on the datalayout. Make sure we
// finalize the datalayout before we run any of that code.
bool ResolvedDataLayout = false;
@@ -3516,14 +3518,26 @@ Error BitcodeReader::parseModule(uint64_t ResumeBit,
// datalayout and triple can't be parsed after this point.
ResolvedDataLayout = true;
// Upgrade data layout string.
std::string DL = llvm::UpgradeDataLayoutString(
TheModule->getDataLayoutStr(), TheModule->getTargetTriple());
TheModule->setDataLayout(DL);
if (TheModule->isHeterogenousModule()) {
// Upgrade data layout string.
std::string DL = llvm::UpgradeDataLayoutString(
TheModule->getDataLayoutStr(CurrentTargetId),
TheModule->getTargetTriple(CurrentTargetId));
TheModule->setDataLayout(DL, CurrentTargetId);
if (auto LayoutOverride =
DataLayoutCallback(TheModule->getTargetTriple()))
TheModule->setDataLayout(*LayoutOverride);
if (auto LayoutOverride =
DataLayoutCallback(TheModule->getTargetTriple(CurrentTargetId)))
TheModule->setDataLayout(*LayoutOverride, CurrentTargetId);
} else {
// Upgrade data layout string.
std::string DL = llvm::UpgradeDataLayoutString(
TheModule->getDataLayoutStr(), TheModule->getTargetTriple());
TheModule->setDataLayout(DL);
if (auto LayoutOverride =
DataLayoutCallback(TheModule->getTargetTriple()))
TheModule->setDataLayout(*LayoutOverride);
}
};
// Read all the records for this module.
@@ -3682,6 +3696,16 @@ Error BitcodeReader::parseModule(uint64_t ResumeBit,
return MaybeBitCode.takeError();
switch (unsigned BitCode = MaybeBitCode.get()) {
default: break; // Default behavior, ignore unknown content.
case bitc::MODULE_CODE_TARGET_ID: {
std::string S;
if (convertToString(Record, 0, S))
return error("Invalid record");
CurrentTargetId = std::stoul(S);
if (!TheModule->isHeterogenousModule())
TheModule->markHeterogenous();
TheModule->setNumTargets(CurrentTargetId + 1);
break;
}
case bitc::MODULE_CODE_VERSION: {
Expected<unsigned> VersionOrErr = parseVersionRecord(Record);
if (!VersionOrErr)
@@ -3695,7 +3719,10 @@ Error BitcodeReader::parseModule(uint64_t ResumeBit,
std::string S;
if (convertToString(Record, 0, S))
return error("Invalid record");
TheModule->setTargetTriple(S);
if (TheModule->isHeterogenousModule())
TheModule->setTargetTriple(S, CurrentTargetId);
else
TheModule->setTargetTriple(S);
break;
}
case bitc::MODULE_CODE_DATALAYOUT: { // DATALAYOUT: [strchr x N]
@@ -3704,14 +3731,20 @@ Error BitcodeReader::parseModule(uint64_t ResumeBit,
std::string S;
if (convertToString(Record, 0, S))
return error("Invalid record");
TheModule->setDataLayout(S);
if (TheModule->isHeterogenousModule())
TheModule->setDataLayout(S, CurrentTargetId);
else
TheModule->setDataLayout(S);
break;
}
case bitc::MODULE_CODE_ASM: { // ASM: [strchr x N]
std::string S;
if (convertToString(Record, 0, S))
return error("Invalid record");
TheModule->setModuleInlineAsm(S);
if (TheModule->isHeterogenousModule())
TheModule->setModuleInlineAsm(S, CurrentTargetId);
else
TheModule->setModuleInlineAsm(S);
break;
}
case bitc::MODULE_CODE_DEPLIB: { // DEPLIB: [strchr x N]

View File

@@ -1189,16 +1189,43 @@ static StringEncoding getStringEncoding(StringRef Str) {
/// descriptors for global variables, and function prototype info.
/// Returns the bit offset to backpatch with the location of the real VST.
void ModuleBitcodeWriter::writeModuleInfo() {
// Emit various pieces of data attached to a module.
if (!M.getTargetTriple().empty())
writeStringRecord(Stream, bitc::MODULE_CODE_TRIPLE, M.getTargetTriple(),
0 /*TODO*/);
const std::string &DL = M.getDataLayoutStr();
if (!DL.empty())
writeStringRecord(Stream, bitc::MODULE_CODE_DATALAYOUT, DL, 0 /*TODO*/);
if (!M.getModuleInlineAsm().empty())
writeStringRecord(Stream, bitc::MODULE_CODE_ASM, M.getModuleInlineAsm(),
0 /*TODO*/);
// If a heterogenous module only contains one target, we take it as regular
// module
// FIXME: Finish this part
if (M.isHeterogenousModule() && M.getNumTargets() > 1) {
const auto NumTargets = M.getNumTargets();
writeStringRecord(Stream, bitc::MODULE_CODE_NUM_TARGETS,
std::to_string(NumTargets), 0 /*TODO*/);
for (unsigned I = 0; I < NumTargets; ++I) {
const std::string &TargetTriple = M.getTargetTriple(I);
const std::string &DL = M.getDataLayoutStr(I);
const std::string &ModuleInlineAsm = M.getModuleInlineAsm(I);
if (!TargetTriple.empty() || !DL.empty() || !ModuleInlineAsm.empty())
writeStringRecord(Stream, bitc::MODULE_CODE_TARGET_ID,
std::to_string(I), 0 /*TODO*/);
if (!TargetTriple.empty())
writeStringRecord(Stream, bitc::MODULE_CODE_TRIPLE, TargetTriple,
0 /*TODO*/);
if (!DL.empty())
writeStringRecord(Stream, bitc::MODULE_CODE_DATALAYOUT, DL, 0 /*TODO*/);
if (!ModuleInlineAsm.empty())
writeStringRecord(Stream, bitc::MODULE_CODE_ASM, ModuleInlineAsm,
0 /*TODO*/);
}
} else {
// Emit various pieces of data attached to a module.
if (!M.getTargetTriple().empty())
writeStringRecord(Stream, bitc::MODULE_CODE_TRIPLE, M.getTargetTriple(),
0 /*TODO*/);
const std::string &DL = M.getDataLayoutStr();
if (!DL.empty())
writeStringRecord(Stream, bitc::MODULE_CODE_DATALAYOUT, DL, 0 /*TODO*/);
if (!M.getModuleInlineAsm().empty())
writeStringRecord(Stream, bitc::MODULE_CODE_ASM, M.getModuleInlineAsm(),
0 /*TODO*/);
}
// Emit information about sections and GC, computing how many there are. Also
// compute the maximum alignment value.

View File

@@ -98,17 +98,41 @@ AbstractCallSite::AbstractCallSite(const Use *U)
return;
}
unsigned UseIdx = CB->getArgOperandNo(U);
MDNode *CallbackEncMD = nullptr;
for (const MDOperand &Op : CallbackMD->operands()) {
auto GetCBCalleeIdx = [](const MDOperand &Op) {
MDNode *OpMD = cast<MDNode>(Op.get());
auto *CBCalleeIdxAsCM = cast<ConstantAsMetadata>(OpMD->getOperand(0));
uint64_t CBCalleeIdx =
cast<ConstantInt>(CBCalleeIdxAsCM->getValue())->getZExtValue();
if (CBCalleeIdx != UseIdx)
continue;
CallbackEncMD = OpMD;
break;
return cast<ConstantInt>(CBCalleeIdxAsCM->getValue())->getZExtValue();
};
if (CB->isBundleOperand(U)) {
// If the Use is a bundle operand, we're constructing a callsite for
// heterogenous callback.
for (const MDOperand &Op : CallbackMD->operands()) {
uint64_t CBCalleeIdx = GetCBCalleeIdx(Op);
assert(CBCalleeIdx < CB->getNumArgOperands());
llvm::Value *CBCallee = CB->getOperand(CBCalleeIdx);
// We suppose the callback functions are stored in operand bundles with
// name same as the callback global variable.
auto OpBundleOption = CB->getOperandBundle(CBCallee->getName());
if (!OpBundleOption)
continue;
if (OpBundleOption->isUseInBundle(U)) {
CallbackEncMD = cast<MDNode>(Op.get());
break;
}
}
} else {
unsigned UseIdx = CB->getArgOperandNo(U);
for (const MDOperand &Op : CallbackMD->operands()) {
uint64_t CBCalleeIdx = GetCBCalleeIdx(Op);
assert(CBCalleeIdx < CB->getNumArgOperands());
if (CBCalleeIdx != UseIdx)
continue;
CallbackEncMD = cast<MDNode>(Op.get());
break;
}
}
if (!CallbackEncMD) {
@@ -148,7 +172,10 @@ AbstractCallSite::AbstractCallSite(const Use *U)
if (VarArgFlagAsCM->getValue()->isNullValue())
return;
// Add all variadic arguments at the end.
// Add all variadic arguments
for (unsigned u = Callee->arg_size(); u < NumCallOperands; u++)
CI.ParameterEncoding.push_back(u);
// Add the operand number of the Use at the end
CI.ParameterEncoding.push_back(U->getOperandNo());
}

View File

@@ -2786,27 +2786,43 @@ void AssemblyWriter::printModule(const Module *M) {
Out << "\"\n";
}
const std::string &DL = M->getDataLayoutStr();
if (!DL.empty())
Out << "target datalayout = \"" << DL << "\"\n";
if (!M->getTargetTriple().empty())
Out << "target triple = \"" << M->getTargetTriple() << "\"\n";
if (M->isHeterogenousModule()) {
for (unsigned I = 0; I < M->getNumTargets(); ++I) {
const std::string TargetId(std::move(std::to_string(I)));
if (!M->getModuleInlineAsm().empty()) {
Out << '\n';
const std::string &DL = M->getDataLayoutStr(I);
if (!DL.empty())
Out << "target " << TargetId << " datalayout = \"" << DL << "\"\n";
// Split the string into lines, to make it easier to read the .ll file.
StringRef Asm = M->getModuleInlineAsm();
do {
StringRef Front;
std::tie(Front, Asm) = Asm.split('\n');
const std::string &T = M->getTargetTriple(I);
if (!T.empty())
Out << "target " << TargetId << " triple = \"" << T << "\"\n";
// We found a newline, print the portion of the asm string from the
// last newline up to this newline.
Out << "module asm \"";
printEscapedString(Front, Out);
Out << "\"\n";
} while (!Asm.empty());
// FIXME: Print inline asm
}
} else {
const std::string &DL = M->getDataLayoutStr();
if (!DL.empty())
Out << "target datalayout = \"" << DL << "\"\n";
if (!M->getTargetTriple().empty())
Out << "target triple = \"" << M->getTargetTriple() << "\"\n";
if (!M->getModuleInlineAsm().empty()) {
Out << '\n';
// Split the string into lines, to make it easier to read the .ll file.
StringRef Asm = M->getModuleInlineAsm();
do {
StringRef Front;
std::tie(Front, Asm) = Asm.split('\n');
// We found a newline, print the portion of the asm string from the
// last newline up to this newline.
Out << "module asm \"";
printEscapedString(Front, Out);
Out << "\"\n";
} while (!Asm.empty());
}
}
printTypeIdentities();

View File

@@ -27,6 +27,7 @@
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
@@ -38,6 +39,8 @@
#include <tuple>
#include <utility>
#define DEBUG_TYPE "data-layout"
using namespace llvm;
//===----------------------------------------------------------------------===//
@@ -547,6 +550,51 @@ bool DataLayout::operator==(const DataLayout &Other) const {
return Ret;
}
bool DataLayout::isCompatibleWith(const DataLayout &Other) const {
if (*this == Other)
return true;
bool Partial = BigEndian == Other.BigEndian;
if (StackNaturalAlign.hasValue() && Other.StackNaturalAlign.hasValue())
Partial = Partial && (StackNaturalAlign.getValue() ==
Other.StackNaturalAlign.getValue());
if (FunctionPtrAlign.hasValue() && Other.FunctionPtrAlign.hasValue())
Partial = Partial && (FunctionPtrAlign.getValue() ==
Other.FunctionPtrAlign.getValue());
if (!Partial)
return false;
// Compare the two Alignments. Suppose they're stored w/o any order (O(n^2)).
// Otherwise, this procedure can be optimized to O(n).
for (auto I = Alignments.begin(); I != Alignments.end(); ++I) {
for (auto J = Other.Alignments.begin(); J != Other.Alignments.end(); ++J) {
if (I->AlignType != J->AlignType || I->TypeBitWidth != J->TypeBitWidth)
continue;
// TODO: Do we care about preferred size vs ABI size, or both?
if (!(*I == *J))
return false;
}
}
// Compare pointers. Like before, suppose they're not ordered.
for (auto I = Pointers.begin(); I != Pointers.end(); ++I) {
for (auto J = Other.Pointers.begin(); J != Other.Pointers.end(); ++J) {
if (I->AddressSpace != J->AddressSpace)
continue;
// TODO: Similar question as above
if (!(*I == *J))
return false;
}
}
return true;
}
DataLayout::AlignmentsTy::iterator
DataLayout::findAlignmentLowerBound(AlignTypeEnum AlignType,
uint32_t BitWidth) {

View File

@@ -74,8 +74,13 @@ template class llvm::SymbolTableListTraits<GlobalIFunc>;
Module::Module(StringRef MID, LLVMContext &C)
: Context(C), ValSymTab(std::make_unique<ValueSymbolTable>()),
Materializer(), ModuleID(std::string(MID)),
SourceFileName(std::string(MID)), DL("") {
SourceFileName(std::string(MID)), DL(""), IsHeterogenousModule(false),
NumTargets(0) {
Context.addModule(this);
TargetTriples.resize(LLVM_MODULE_NUM_TARGETS);
GlobalScopeAsms.resize(LLVM_MODULE_NUM_TARGETS);
DLs.resize(LLVM_MODULE_NUM_TARGETS, DataLayout(""));
}
Module::~Module() {
@@ -111,9 +116,25 @@ Module::createRNG(const StringRef Name) const {
/// the specified name, of arbitrary type. This method returns null
/// if a global with the specified name is not found.
GlobalValue *Module::getNamedValue(StringRef Name) const {
// If it is a heterogenous module, the name is not mangled and internal, we
// need to search using mangled name
if (IsHeterogenousModule && !llvm::heterogenous::isInternalName(Name) &&
!llvm::heterogenous::isMangledName(Name))
return getNamedValue(Name, ActiveTarget);
return cast_or_null<GlobalValue>(getValueSymbolTable().lookup(Name));
}
GlobalValue *Module::getNamedValue(StringRef Name, unsigned TargetId) const {
assert(IsHeterogenousModule &&
"This function should be only called when the module is heterogenous");
assert(TargetId < LLVM_MODULE_NUM_TARGETS &&
"TargetId is expected in [0, 31]");
return cast_or_null<GlobalValue>(getValueSymbolTable().lookup(
llvm::heterogenous::mangleName(Name, TargetId)));
}
/// getMDKindID - Return a unique non-zero ID for the specified metadata kind.
/// This ID is uniqued across modules in the current LLVMContext.
unsigned Module::getMDKindID(StringRef Name) const {
@@ -148,7 +169,7 @@ FunctionCallee Module::getOrInsertFunction(StringRef Name, FunctionType *Ty,
// Nope, add it
Function *New = Function::Create(Ty, GlobalVariable::ExternalLinkage,
DL.getProgramAddressSpace(), Name);
if (!New->isIntrinsic()) // Intrinsics get attrs set on construction
if (!New->isIntrinsic()) // Intrinsics get attrs set on construction
New->setAttributes(AttributeList);
FunctionList.push_back(New);
return {Ty, New}; // Return the new prototype.
@@ -164,10 +185,31 @@ FunctionCallee Module::getOrInsertFunction(StringRef Name, FunctionType *Ty,
return {Ty, F};
}
FunctionCallee Module::getOrInsertFunction(StringRef Name, FunctionType *Ty,
AttributeList AttributeList,
unsigned TargetId) {
assert(IsHeterogenousModule &&
"This function should be only called when the module is heterogenous");
assert(TargetId < LLVM_MODULE_NUM_TARGETS &&
"TargetId is expected in [0, 31]");
return getOrInsertFunction(llvm::heterogenous::mangleName(Name, TargetId), Ty,
AttributeList);
}
FunctionCallee Module::getOrInsertFunction(StringRef Name, FunctionType *Ty) {
return getOrInsertFunction(Name, Ty, AttributeList());
}
FunctionCallee Module::getOrInsertFunction(StringRef Name, FunctionType *Ty,
unsigned TargetId) {
assert(IsHeterogenousModule &&
"This function should be only called when the module is heterogenous");
assert(TargetId < LLVM_MODULE_NUM_TARGETS &&
"TargetId is expected in [0, 31]");
return getOrInsertFunction(Name, Ty, AttributeList(), TargetId);
}
// getFunction - Look up the specified function in the module symbol table.
// If it does not exist, return null.
//
@@ -175,6 +217,14 @@ Function *Module::getFunction(StringRef Name) const {
return dyn_cast_or_null<Function>(getNamedValue(Name));
}
Function *Module::getFunction(StringRef Name, unsigned TargetId) const {
assert(IsHeterogenousModule &&
"This function should be only called when the module is heterogenous");
assert(TargetId < LLVM_MODULE_NUM_TARGETS &&
"TargetId is expected in [0, 31]");
return getFunction(llvm::heterogenous::mangleName(Name, TargetId));
}
//===----------------------------------------------------------------------===//
// Methods for easy access to the global variables in the module.
//
@@ -189,12 +239,22 @@ Function *Module::getFunction(StringRef Name) const {
GlobalVariable *Module::getGlobalVariable(StringRef Name,
bool AllowLocal) const {
if (GlobalVariable *Result =
dyn_cast_or_null<GlobalVariable>(getNamedValue(Name)))
dyn_cast_or_null<GlobalVariable>(getNamedValue(Name)))
if (AllowLocal || !Result->hasLocalLinkage())
return Result;
return nullptr;
}
GlobalVariable *Module::getGlobalVariable(StringRef Name, bool AllowLocal,
unsigned TargetId) const {
assert(IsHeterogenousModule &&
"This function should be only called when the module is heterogenous");
assert(TargetId < LLVM_MODULE_NUM_TARGETS &&
"TargetId is expected in [0, 31]");
return getGlobalVariable(llvm::heterogenous::mangleName(Name, TargetId),
AllowLocal);
}
/// getOrInsertGlobal - Look up the specified global in the module symbol table.
/// 1. If it does not exist, add a declaration of the global and return it.
/// 2. Else, the global exists but has the wrong type: return the function
@@ -221,6 +281,18 @@ Constant *Module::getOrInsertGlobal(
return GV;
}
Constant *
Module::getOrInsertGlobal(StringRef Name, Type *Ty,
function_ref<GlobalVariable *()> CreateGlobalCallback,
unsigned TargetId) {
assert(IsHeterogenousModule &&
"This function should be only called when the module is heterogenous");
assert(TargetId < LLVM_MODULE_NUM_TARGETS &&
"TargetId is expected in [0, 31]");
return getOrInsertGlobal(llvm::heterogenous::mangleName(Name, TargetId), Ty,
CreateGlobalCallback);
}
// Overload to construct a global variable using its constructor's defaults.
Constant *Module::getOrInsertGlobal(StringRef Name, Type *Ty) {
return getOrInsertGlobal(Name, Ty, [&] {
@@ -229,6 +301,15 @@ Constant *Module::getOrInsertGlobal(StringRef Name, Type *Ty) {
});
}
Constant *Module::getOrInsertGlobal(StringRef Name, Type *Ty,
unsigned TargetId) {
assert(IsHeterogenousModule &&
"This function should be only called when the module is heterogenous");
assert(TargetId < LLVM_MODULE_NUM_TARGETS &&
"TargetId is expected in [0, 31]");
return getOrInsertGlobal(llvm::heterogenous::mangleName(Name, TargetId), Ty);
}
//===----------------------------------------------------------------------===//
// Methods for easy access to the global variables in the module.
//
@@ -237,13 +318,37 @@ Constant *Module::getOrInsertGlobal(StringRef Name, Type *Ty) {
// If it does not exist, return null.
//
GlobalAlias *Module::getNamedAlias(StringRef Name) const {
if (IsHeterogenousModule)
return getNamedAlias(Name, ActiveTarget);
return dyn_cast_or_null<GlobalAlias>(getNamedValue(Name));
}
GlobalAlias *Module::getNamedAlias(StringRef Name, unsigned TargetId) const {
assert(IsHeterogenousModule &&
"This function should be only called when the module is heterogenous");
assert(TargetId < LLVM_MODULE_NUM_TARGETS &&
"TargetId is expected in [0, 31]");
return dyn_cast_or_null<GlobalAlias>(getNamedValue(Name, TargetId));
}
GlobalIFunc *Module::getNamedIFunc(StringRef Name) const {
if (IsHeterogenousModule)
return getNamedIFunc(Name, ActiveTarget);
return dyn_cast_or_null<GlobalIFunc>(getNamedValue(Name));
}
GlobalIFunc *Module::getNamedIFunc(StringRef Name, unsigned TargetId) const {
assert(IsHeterogenousModule &&
"This function should be only called when the module is heterogenous");
assert(TargetId < LLVM_MODULE_NUM_TARGETS &&
"TargetId is expected in [0, 31]");
return dyn_cast_or_null<GlobalIFunc>(getNamedValue(Name, TargetId));
}
/// getNamedMetadata - Return the first NamedMDNode in the module with the
/// specified name. This method returns null if a NamedMDNode with the
/// specified name is not found.
@@ -299,10 +404,11 @@ bool Module::isValidModuleFlag(const MDNode &ModFlag, ModFlagBehavior &MFB,
}
/// getModuleFlagsMetadata - Returns the module flags in the provided vector.
void Module::
getModuleFlagsMetadata(SmallVectorImpl<ModuleFlagEntry> &Flags) const {
void Module::getModuleFlagsMetadata(
SmallVectorImpl<ModuleFlagEntry> &Flags) const {
const NamedMDNode *ModFlags = getModuleFlagsMetadata();
if (!ModFlags) return;
if (!ModFlags)
return;
for (const MDNode *Flag : ModFlags->operands()) {
ModFlagBehavior MFB;
@@ -388,13 +494,66 @@ void Module::setModuleFlag(ModFlagBehavior Behavior, StringRef Key,
addModuleFlag(Behavior, Key, Val);
}
void Module::setDataLayout(StringRef Desc) {
DL.reset(Desc);
}
void Module::setDataLayout(StringRef Desc) { DL.reset(Desc); }
void Module::setDataLayout(const DataLayout &Other) { DL = Other; }
const DataLayout &Module::getDataLayout() const { return DL; }
void Module::setDataLayout(StringRef Desc, unsigned TargetId) {
assert(IsHeterogenousModule &&
"This function should be only called when the module is heterogenous");
assert(TargetId < LLVM_MODULE_NUM_TARGETS &&
"TargetId is expected in [0, 31]");
if (TargetId == 0)
return setDataLayout(Desc);
assert(IsHeterogenousModule &&
"IsHeterogenousModule should be true if trying to set the target "
"layout of non-first target");
DLs[TargetId - 1].reset(Desc);
}
void Module::setDataLayout(const DataLayout &Other, unsigned TargetId) {
assert(IsHeterogenousModule &&
"This function should be only called when the module is heterogenous");
assert(TargetId < LLVM_MODULE_NUM_TARGETS &&
"TargetId is expected in [0, 31]");
if (TargetId == 0)
return setDataLayout(Other);
assert(IsHeterogenousModule &&
"IsHeterogenousModule should be true if trying to set the target "
"layout of non-first target");
DLs[TargetId - 1] = Other;
}
const DataLayout &Module::getDataLayout() const {
// If it is a heterogenous module, we need to return the one of active target
if (IsHeterogenousModule)
return getDataLayout(ActiveTarget);
// Otherwise, the original one will be returned
return DL;
}
const DataLayout &Module::getDataLayout(unsigned TargetId) const {
assert(IsHeterogenousModule &&
"This function should be only called when the module is heterogenous");
assert(TargetId < LLVM_MODULE_NUM_TARGETS &&
"TargetId is expected in [0, 31]");
if (TargetId == 0)
return DL;
assert(IsHeterogenousModule &&
"IsHeterogenousModule should be true if trying to get the data layout "
"of non-first target");
return DLs[TargetId - 1];
}
DICompileUnit *Module::debug_compile_units_iterator::operator*() const {
return cast<DICompileUnit>(CUs->getOperand(Idx));
@@ -529,11 +688,29 @@ unsigned Module::getInstructionCount() {
}
Comdat *Module::getOrInsertComdat(StringRef Name) {
if (IsHeterogenousModule)
return getOrInsertComdat(Name, ActiveTarget);
auto &Entry = *ComdatSymTab.insert(std::make_pair(Name, Comdat())).first;
Entry.second.Name = &Entry;
return &Entry.second;
}
Comdat *Module::getOrInsertComdat(StringRef Name, unsigned TargetId) {
assert(IsHeterogenousModule &&
"This function should be only called when the module is heterogenous");
assert(TargetId < LLVM_MODULE_NUM_TARGETS &&
"TargetId is expected in [0, 31]");
auto &Entry =
*ComdatSymTab
.insert(std::make_pair(
llvm::heterogenous::mangleName(Name, TargetId), Comdat()))
.first;
Entry.second.Name = &Entry;
return &Entry.second;
}
PICLevel::Level Module::getPICLevel() const {
auto *Val = cast_or_null<ConstantAsMetadata>(getModuleFlag("PIC Level"));

View File

@@ -1416,34 +1416,43 @@ Error IRLinker::run() {
if (Error Err = SrcM->getMaterializer()->materializeMetadata())
return Err;
// Inherit the target data from the source module if the destination module
// doesn't have one already.
if (DstM.getDataLayout().isDefault())
DstM.setDataLayout(SrcM->getDataLayout());
if (SrcM->getDataLayout() != DstM.getDataLayout()) {
emitWarning("Linking two modules of different data layouts: '" +
SrcM->getModuleIdentifier() + "' is '" +
SrcM->getDataLayoutStr() + "' whereas '" +
DstM.getModuleIdentifier() + "' is '" +
DstM.getDataLayoutStr() + "'\n");
}
// Copy the target triple from the source to dest if the dest's is empty.
if (DstM.getTargetTriple().empty() && !SrcM->getTargetTriple().empty())
DstM.setTargetTriple(SrcM->getTargetTriple());
const auto IsHeterogenousModule = DstM.isHeterogenousModule();
Triple SrcTriple(SrcM->getTargetTriple()), DstTriple(DstM.getTargetTriple());
if (!SrcM->getTargetTriple().empty()&&
!SrcTriple.isCompatibleWith(DstTriple))
emitWarning("Linking two modules of different target triples: '" +
SrcM->getModuleIdentifier() + "' is '" +
SrcM->getTargetTriple() + "' whereas '" +
DstM.getModuleIdentifier() + "' is '" + DstM.getTargetTriple() +
"'\n");
// If it is a heterogenous module, we need to do something different
if (IsHeterogenousModule) {
const auto CurrentTargetId = DstM.getActiveTarget();
DstM.setDataLayout(SrcM->getDataLayout(), CurrentTargetId);
DstM.setTargetTriple(SrcTriple.getTriple(), CurrentTargetId);
} else {
// Inherit the target data from the source module if the destination module
// doesn't have one already.
if (DstM.getDataLayout().isDefault())
DstM.setDataLayout(SrcM->getDataLayout());
DstM.setTargetTriple(SrcTriple.merge(DstTriple));
if (SrcM->getDataLayout() != DstM.getDataLayout()) {
emitWarning("Linking two modules of different data layouts: '" +
SrcM->getModuleIdentifier() + "' is '" +
SrcM->getDataLayoutStr() + "' whereas '" +
DstM.getModuleIdentifier() + "' is '" +
DstM.getDataLayoutStr() + "'\n");
}
// Copy the target triple from the source to dest if the dest's is empty.
if (DstM.getTargetTriple().empty() && !SrcM->getTargetTriple().empty())
DstM.setTargetTriple(SrcM->getTargetTriple());
if (!SrcM->getTargetTriple().empty() &&
!SrcTriple.isCompatibleWith(DstTriple))
emitWarning("Linking two modules of different target triples: '" +
SrcM->getModuleIdentifier() + "' is '" +
SrcM->getTargetTriple() + "' whereas '" +
DstM.getModuleIdentifier() + "' is '" +
DstM.getTargetTriple() + "'\n");
DstM.setTargetTriple(SrcTriple.merge(DstTriple));
}
// Loop over all of the linked values to compute type mappings.
computeTypeMapping();
@@ -1475,10 +1484,14 @@ Error IRLinker::run() {
// are properly remapped.
linkNamedMDNodes();
if (!IsPerformingImport && !SrcM->getModuleInlineAsm().empty()) {
// Append the module inline asm string.
DstM.appendModuleInlineAsm(adjustInlineAsm(SrcM->getModuleInlineAsm(),
SrcTriple));
if (!IsPerformingImport) {
if (IsHeterogenousModule)
DstM.appendModuleInlineAsm(SrcM->getModuleInlineAsm(),
DstM.getActiveTarget());
else if (!SrcM->getModuleInlineAsm().empty())
// Append the module inline asm string.
DstM.appendModuleInlineAsm(
adjustInlineAsm(SrcM->getModuleInlineAsm(), SrcTriple));
} else if (IsPerformingImport) {
// Import any symver directives for symbols in DstM.
ModuleSymbolTable::CollectAsmSymvers(*SrcM,

View File

@@ -20,6 +20,7 @@
#include "llvm/IR/Module.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Support/Error.h"
#include "llvm/Transforms/Utils/HeterogenousModuleUtils.h"
using namespace llvm;
namespace {
@@ -569,11 +570,53 @@ bool ModuleLinker::run() {
return false;
}
namespace {
/// Helper class to manage target number when linking a new module into a
/// heterogenous module.
struct HeterogenousModuleLinkerHelper {
Module &DstM;
Module &SrcM;
public:
HeterogenousModuleLinkerHelper(Module &Dst, Module &Src)
: DstM(Dst), SrcM(Src) {
assert(DstM.isHeterogenousModule());
const unsigned NumTargets = DstM.getNumTargets();
DstM.setActiveTarget(NumTargets);
llvm::heterogenous::renameModuleSymbols(Src, NumTargets);
}
/// After the link, set the active target to 0
~HeterogenousModuleLinkerHelper() {
DstM.setNumTargets(DstM.getNumTargets() + 1);
DstM.setActiveTarget(0);
}
/// Return true if the two modules have compatible data layout
bool isCompatible() const {
return DstM.getDataLayout(0).isCompatibleWith(SrcM.getDataLayout());
}
};
} // namespace
Linker::Linker(Module &M) : Mover(M) {}
bool Linker::linkInModule(
std::unique_ptr<Module> Src, unsigned Flags,
std::function<void(Module &, const StringSet<> &)> InternalizeCallback) {
Module &DstM = Mover.getModule();
std::unique_ptr<HeterogenousModuleLinkerHelper> Helper;
if (DstM.isHeterogenousModule()) {
Helper = std::make_unique<HeterogenousModuleLinkerHelper>(DstM, *Src);
// We skip the check for the first module
if (DstM.getNumTargets() && !Helper->isCompatible()) {
DstM.getContext().diagnose(LinkDiagnosticInfo(
DS_Error, "Uncompatible modules cannot be linked together"));
return true;
}
}
ModuleLinker ModLinker(Mover, std::move(Src), Flags,
std::move(InternalizeCallback));
return ModLinker.run();

View File

@@ -336,7 +336,10 @@ struct OMPInformationCache : public InformationCache {
void initializeRuntimeFunctions() {
Module &M = *((*ModuleSlice.begin())->getParent());
// Helper macros for handling __VA_ARGS__ in OMP_RTL
for (unsigned I = 0; I < M.getNumTargets(); ++I) {
M.setActiveTarget(I);
// Helper macros for handling __VA_ARGS__ in OMP_RTL
#define OMP_TYPE(VarName, ...) \
Type *VarName = OMPBuilder.VarName; \
(void)VarName;
@@ -384,6 +387,7 @@ struct OMPInformationCache : public InformationCache {
} \
}
#include "llvm/Frontend/OpenMP/OMPKinds.def"
}
// TODO: We should attach the attributes defined in OMPKinds.def.
}
@@ -510,6 +514,9 @@ struct OpenMPOpt {
if (PrintOpenMPKernels)
printKernels();
if (M.isHeterogenousModule())
Changed |= addCallBackOperandBundles();
Changed |= rewriteDeviceCodeStateMachine();
Changed |= runAttributor();
@@ -1503,6 +1510,8 @@ private:
/// the cases we can avoid taking the address of a function.
bool rewriteDeviceCodeStateMachine();
bool addCallBackOperandBundles();
///
///}}
@@ -1804,6 +1813,114 @@ bool OpenMPOpt::rewriteDeviceCodeStateMachine() {
return Changed;
}
bool OpenMPOpt::addCallBackOperandBundles() {
assert(M.isHeterogenousModule());
bool Changed = false;
SmallVector<std::pair<CallInst *, int64_t>, 8> WorkList;
auto FindCallBackIdx = [&WorkList](llvm::Use &U, llvm::Function &) {
CallInst *CI = dyn_cast<CallInst>(U.getUser());
if (!CI)
return false;
// FIXME: Temporary solution
WorkList.push_back(std::make_pair(CI, 2));
// Function *Callee = CI->getCalledFunction();
// MDNode *CallbackMD = Callee->getMetadata(LLVMContext::MD_callback);
// if (!CallbackMD)
// return false;
// MDNode *CallbackEncMD = cast<MDNode>(CallbackMD->getOperand(0));
// assert(CallbackEncMD && "Empty callback metadata");
// assert(CallbackEncMD->getNumOperands() >= 2 &&
// "Incomplete !callback metadata");
// Metadata *OpAsM = CallbackEncMD->getOperand(0).get();
// auto *OpAsCM = cast<ConstantAsMetadata>(OpAsM);
// assert(OpAsCM->getType()->isIntegerTy(64) &&
// "Malformed !callback metadata");
// int64_t Idx = cast<ConstantInt>(OpAsCM->getValue())->getSExtValue();
// assert(Idx >= 0 && Idx < CI->getNumArgOperands() &&
// "Out-of-bounds !callback metadata index");
// WorkList.push_back(std::make_pair(CI, Idx));
return false;
};
RuntimeFunction TargetRuntimeCallIDs[] = {
OMPRTL___tgt_target_mapper, OMPRTL___tgt_target_nowait_mapper,
OMPRTL___tgt_target_teams_mapper,
OMPRTL___tgt_target_teams_nowait_mapper};
for (auto TargetRuntimeCallID : TargetRuntimeCallIDs) {
OMPInformationCache::RuntimeFunctionInfo &RFI =
OMPInfoCache.RFIs[TargetRuntimeCallID];
if (!RFI)
continue;
for (Function *F : SCC)
RFI.foreachUse(FindCallBackIdx, F);
}
if (WorkList.empty())
return Changed;
for (auto &WorkItem : WorkList) {
CallInst *CI = WorkItem.first;
int64_t Idx = WorkItem.second;
Value *KernelIdOp = CI->getArgOperand(Idx);
assert(KernelIdOp && "KernelIdOp shoule not be null");
GlobalVariable *KernelIdVar = dyn_cast<GlobalVariable>(KernelIdOp);
assert(KernelIdVar && "KernelIdOp should be a global variable");
// Extract kernel function name from kernel id. Kernel id is composed by:
// 1. A leading dot ".";
// 2. Kernel function name;
// 3. Tailing ".region_id"
// For example, if we have a kernel function named "kernel_func", then the
// kernel id, which will be a global variable valued 0, will be named
// ".kernel_func.region_id".
// Therefore, now given the kernel id, we can get the kernel function name
// by removing the leading dot and tailing ".region_id".
StringRef KernelId =
llvm::heterogenous::demangleName(KernelIdVar->getName()).second;
StringRef KernelFuncName = KernelId.substr(1, KernelId.size() - 11);
// We might have multiple kernel funcions with the same original name in a
// heterogenous module. We need to get all of them.
SmallVector<Value *, 4> Inputs;
for (unsigned I = 0; I < M.getNumTargets(); ++I) {
// It is possible that on a specific target, the kernel doesn't exist,
// e.g. when `declare variant` is being used.
if (Function *F = M.getFunction(KernelFuncName, I))
Inputs.push_back(F);
}
// If there is no kernel found, just skip current CI.
if (Inputs.empty())
continue;
if (!Changed)
Changed = true;
// Corresponding kernel functions are encoded into an operand bundle with
// the name same as KernelId.
OperandBundleDef OB(std::string(KernelId), Inputs);
auto *NewCI = CallInst::Create(CI, OB, CI);
CI->replaceAllUsesWith(NewCI);
CI->eraseFromParent();
}
return Changed;
}
/// Abstract Attribute for tracking ICV values.
struct AAICVTracker : public StateWrapper<BooleanState, AbstractAttribute> {
using Base = StateWrapper<BooleanState, AbstractAttribute>;

View File

@@ -29,6 +29,7 @@ add_llvm_component_library(LLVMTransformUtils
GlobalStatus.cpp
GuardUtils.cpp
HelloWorld.cpp
HeterogenousModuleUtils.cpp
InlineFunction.cpp
InjectTLIMappings.cpp
InstructionNamer.cpp

View File

@@ -0,0 +1,54 @@
//===- HeterogenousModuleUtils.cpp - Utilities for heterogenous module -======//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements some utility functions for heterogenous module.
//
//===----------------------------------------------------------------------===//
#include "llvm/Transforms/Utils/HeterogenousModuleUtils.h"
#include "llvm/IR/Module.h"
using namespace llvm;
namespace llvm {
namespace heterogenous {
void renameModuleSymbols(Module &SrcM, unsigned TargetId) {
// Global variable
for (GlobalVariable &GV : SrcM.globals()) {
if (!isMangleNeeded(&GV))
continue;
GV.setName(mangleName(GV.getName(), TargetId));
if (Comdat *C = GV.getComdat()) {
Comdat *NC = SrcM.getOrInsertComdat(GV.getName());
NC->setSelectionKind(C->getSelectionKind());
GV.setComdat(NC);
}
}
// Function
for (Function &GV : SrcM.functions()) {
if (!isMangleNeeded(&GV))
continue;
GV.setName(mangleName(GV.getName(), TargetId));
if (Comdat *C = GV.getComdat()) {
Comdat *NC = SrcM.getOrInsertComdat(GV.getName());
NC->setSelectionKind(C->getSelectionKind());
GV.setComdat(NC);
}
}
// Alias
for (GlobalAlias &GV : SrcM.aliases())
GV.setName(mangleName(GV.getName(), TargetId));
// IFunc
for (GlobalIFunc &GV : SrcM.ifuncs())
GV.setName(mangleName(GV.getName(), TargetId));
}
} // namespace heterogenous
} // namespace llvm

View File

@@ -110,6 +110,11 @@ static cl::opt<bool> PreserveAssemblyUseListOrder(
cl::desc("Preserve use-list order when writing LLVM assembly."),
cl::init(false), cl::Hidden);
static cl::opt<bool>
HeterogenousModule("heterogenous-module",
cl::desc("Generate a heterogenous module."),
cl::init(false), cl::Hidden);
static ExitOnError ExitOnErr;
// Read the specified bitcode file in and return it. This routine searches the
@@ -439,6 +444,8 @@ int main(int argc, char **argv) {
Context.enableDebugTypeODRUniquing();
auto Composite = std::make_unique<Module>("llvm-link", Context);
if (HeterogenousModule)
Composite->markHeterogenous();
Linker L(*Composite);
unsigned Flags = Linker::Flags::None;

View File

@@ -53,3 +53,51 @@ TEST(AbstractCallSite, CallbackCall) {
EXPECT_TRUE(ACS.isCallee(CallbackUse));
EXPECT_EQ(ACS.getCalledFunction(), Callback);
}
TEST(AbstractCallSite, HeterogenousCallbackCall) {
LLVMContext C;
const char *IR =
"@region_id = weak constant i8 0\n"
"define void @callback1() {\n"
" ret void\n"
"}\n"
"define void @callback2() {\n"
" ret void\n"
"}\n"
"define void @foo(i32* %A) {\n"
" call void (i32, i8*, ...) @broker(i32 1, i8* @region_id, i32* %A)[\"region_id\"(void ()* @callback1, void ()* @callback2)]\n"
" ret void\n"
"}\n"
"declare !callback !0 void @broker(i32, i8*, ...)\n"
"!0 = !{!1}\n"
"!1 = !{i64 1, i64 -1, i1 true}";
std::unique_ptr<Module> M = parseIR(C, IR);
ASSERT_TRUE(M);
Function *CB1 = M->getFunction("callback1");
Function *CB2 = M->getFunction("callback2");
ASSERT_NE(CB1, nullptr);
ASSERT_NE(CB2, nullptr);
const Use *CB1Use = CB1->getSingleUndroppableUse();
const Use *CB2Use = CB2->getSingleUndroppableUse();
ASSERT_NE(CB1Use, nullptr);
ASSERT_NE(CB2Use, nullptr);
AbstractCallSite ACS1(CB1Use);
AbstractCallSite ACS2(CB2Use);
EXPECT_TRUE(ACS1);
EXPECT_TRUE(ACS2);
EXPECT_TRUE(ACS1.isCallbackCall());
EXPECT_TRUE(ACS2.isCallbackCall());
EXPECT_TRUE(ACS1.isCallee(CB1Use));
EXPECT_TRUE(!ACS1.isCallee(CB2Use));
EXPECT_TRUE(ACS2.isCallee(CB2Use));
EXPECT_TRUE(!ACS2.isCallee(CB1Use));
EXPECT_EQ(ACS1.getCalledFunction(), CB1);
EXPECT_NE(ACS1.getCalledFunction(), CB2);
EXPECT_EQ(ACS2.getCalledFunction(), CB2);
EXPECT_NE(ACS2.getCalledFunction(), CB1);
}