Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions include/infinicore/adaptor/aten_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

#include <ATen/ATen.h>

#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) || defined(ENABLE_ILUVATAR_API)
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
Expand All @@ -31,8 +31,8 @@ inline at::ScalarType to_at_dtype(DataType dtype) {

inline at::Device to_at_device(const Device &device) {
// PyTorch ATen only exposes standard device types (e.g. kCPU/kCUDA).
// Treat MetaX/QY devices as CUDA devices for ATen tensor interoperability.
if (device.getType() == Device::Type::NVIDIA || device.getType() == Device::Type::METAX || device.getType() == Device::Type::QY) {
// Treat CUDA-compatible vendor devices as CUDA devices for ATen tensor interoperability.
if (device.getType() == Device::Type::NVIDIA || device.getType() == Device::Type::METAX || device.getType() == Device::Type::QY || device.getType() == Device::Type::ILUVATAR) {
return at::Device(at::kCUDA, device.getIndex());
} else if (device.getType() == Device::Type::CPU) {
return at::Device(at::kCPU);
Expand All @@ -43,7 +43,7 @@ inline at::Device to_at_device(const Device &device) {

at::Tensor to_aten_tensor(const infinicore::Tensor &t);

#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) || defined(ENABLE_ILUVATAR_API)
c10::cuda::CUDAStream get_cuda_stream();
#endif
} // namespace infinicore::adaptor
Expand Down
19 changes: 12 additions & 7 deletions include/infinicore/adaptor/flash_attention_adaptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
#pragma once
#include "aten_adaptor.hpp"

// NVIDIA flash-attn-nvidia.so uses namespace flash. The pip/MetaX flash_attn_2_cuda extension
// exports the same entry points at global scope (no namespace), matching FLASH_NAMESPACE builds
// where the namespace is empty.
#if !defined(ENABLE_METAX_API)
// NVIDIA flash-attn-nvidia.so uses namespace flash. The pip/MetaX/Iluvatar
// flash_attn_2_cuda extension exports the same entry points at global scope
// (no namespace), matching FLASH_NAMESPACE builds where the namespace is empty.
#if !defined(ENABLE_METAX_API) && !defined(ENABLE_ILUVATAR_API)
namespace flash {
#endif
std::vector<at::Tensor>
Expand Down Expand Up @@ -44,13 +44,18 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_hea
int window_size_right,
const float softcap,
const bool return_softmax,
#if defined(ENABLE_ILUVATAR_API)
const bool deterministic,
int sm_margin,
int max_seqlen_k_new,
#endif
std::optional<at::Generator> gen_
#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
// MetaX/Mars `flash_attn_2_cuda` (e.g. 2.6.x+mars) appends this argument vs upstream Dao-AILab flash-attn.
,
std::optional<at::Tensor> &flash_attn_mars_ext_
#endif
);
);

std::vector<at::Tensor>
mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multiple_of(head_size_og, 8)
Expand Down Expand Up @@ -125,9 +130,9 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size
,
std::optional<at::Tensor> &flash_attn_mars_ext_
#endif
);
);

#if !defined(ENABLE_METAX_API)
#if !defined(ENABLE_METAX_API) && !defined(ENABLE_ILUVATAR_API)
} // namespace flash
#endif
#endif // ENABLE_FLASH_ATTN
2 changes: 1 addition & 1 deletion src/infinicore/adaptor/aten_adaptor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ at::Tensor to_aten_tensor(const infinicore::Tensor &t) {
options);
}

#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) || defined(ENABLE_ILUVATAR_API)
c10::cuda::CUDAStream get_cuda_stream() {
return c10::cuda::getStreamFromExternal(
cudaStream_t(infinicore::context::getStream()), infinicore::context::getDevice().getIndex());
Expand Down
8 changes: 4 additions & 4 deletions src/infinicore/ops/mha_kvcache/mha_kvcache_flashattn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@
#include <stdexcept>

#ifdef ENABLE_FLASH_ATTN
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) || defined(ENABLE_ILUVATAR_API)
#include <c10/cuda/CUDAGuard.h>
#endif
#endif

#if defined(ENABLE_METAX_API)
#if defined(ENABLE_METAX_API) || defined(ENABLE_ILUVATAR_API)
#define INFINICORE_FLASH_OP(name) ::name
#else
#define INFINICORE_FLASH_OP(name) flash::name
Expand Down Expand Up @@ -45,7 +45,7 @@ void *plan(Tensor out,

void run(void *planned_meta) {
#ifdef ENABLE_FLASH_ATTN
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) || defined(ENABLE_ILUVATAR_API)
c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
#endif
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);
Expand All @@ -55,7 +55,7 @@ void run(void *planned_meta) {
Tensor out_work = out_need_copy_back ? p->out->contiguous() : Tensor(p->out);
auto out_tensor = infinicore::adaptor::to_aten_tensor(out_work);
auto q = infinicore::adaptor::to_aten_tensor(p->q);
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_ILUVATAR_API)
auto k_cache = infinicore::adaptor::to_aten_tensor(p->k_cache);
auto v_cache = infinicore::adaptor::to_aten_tensor(p->v_cache);
#elif defined(ENABLE_QY_API)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include <stdexcept>

#ifdef ENABLE_FLASH_ATTN
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) || defined(ENABLE_ILUVATAR_API)
#include <c10/cuda/CUDAGuard.h>
#endif
#endif
Expand Down Expand Up @@ -48,9 +48,9 @@ void *plan(Tensor out,
namespace {

#ifdef ENABLE_FLASH_ATTN
// MetaX/hpcc pip `flash_attn_2_cuda` exports `mha_varlen_fwd` at global scope (no namespace),
// MetaX/hpcc and Iluvatar pip `flash_attn_2_cuda` export `mha_varlen_fwd` at global scope (no namespace),
// while NVIDIA `flash-attn-nvidia.so` uses `flash::mha_varlen_fwd`.
#if defined(ENABLE_METAX_API)
#if defined(ENABLE_METAX_API) || defined(ENABLE_ILUVATAR_API)
#define INFINICORE_FLASH_OP(name) ::name
#else
#define INFINICORE_FLASH_OP(name) flash::name
Expand All @@ -61,7 +61,9 @@ namespace {

void run(void *planned_meta) {
#ifdef ENABLE_FLASH_ATTN
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_METAX_API) || defined(ENABLE_QY_API) || defined(ENABLE_ILUVATAR_API)
c10::cuda::CUDAStreamGuard guard(infinicore::adaptor::get_cuda_stream());
#endif
auto *p = reinterpret_cast<PlannedMeta *>(planned_meta);

auto q = infinicore::adaptor::to_aten_tensor(p->q);
Expand Down Expand Up @@ -109,6 +111,11 @@ void run(void *planned_meta) {
-1,
0.0,
false,
#if defined(ENABLE_ILUVATAR_API)
false,
0,
0,
#endif
std::nullopt
#if defined(ENABLE_METAX_API) && defined(INFINICORE_HPCC_VERSION_MAJOR) && (INFINICORE_HPCC_VERSION_MAJOR >= 3)
,
Expand Down
50 changes: 47 additions & 3 deletions xmake.lua
Original file line number Diff line number Diff line change
Expand Up @@ -478,6 +478,17 @@ target("infinicore_cpp_api")
if has_config("metax-gpu") then
add_deps("flash-attn-metax")
end
if has_config("iluvatar-gpu") then
add_deps("flash-attn-iluvatar")
local flash_attn_root = get_config("flash-attn")
local flash_attn_files = os.files(path.join(flash_attn_root, "flash_attn_2_cuda*.so"))
if not flash_attn_files or #flash_attn_files == 0 then
raise("iluvatar+flash-attn: cannot locate flash_attn_2_cuda under " .. flash_attn_root)
end
local flash_so = flash_attn_files[1]
local flash_dir = path.directory(flash_so)
add_shflags("-Wl,--no-as-needed", flash_so, "-Wl,-rpath," .. flash_dir, {force = true})
end
if has_config("qy-gpu") then
add_deps("flash-attn-qy")
end
Expand Down Expand Up @@ -515,13 +526,19 @@ target("infinicore_cpp_api")
if has_config("aten") then
local outdata = os.iorunv("python", {"-c", "import torch, os; print(os.path.dirname(torch.__file__))"}):trim()
local TORCH_DIR = outdata
local TORCH_CXX11_ABI = os.iorunv("python", {"-c", "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))"}):trim()
local TORCH_ABI_DEFINE = "_GLIBCXX_USE_CXX11_ABI=" .. TORCH_CXX11_ABI

target:add("defines", TORCH_ABI_DEFINE)
target:add("cxflags", "-D" .. TORCH_ABI_DEFINE)
target:add("cxxflags", "-D" .. TORCH_ABI_DEFINE)

target:add(
"includedirs",
path.join(TORCH_DIR, "include"),
"includedirs",
path.join(TORCH_DIR, "include"),
path.join(TORCH_DIR, "include/torch/csrc/api/include"),
{ public = true })

target:add(
"linkdirs",
path.join(TORCH_DIR, "lib"),
Expand All @@ -535,6 +552,22 @@ target("infinicore_cpp_api")
"c10_cuda",
{ public = true }
)

-- Add CUDA/CoreX runtime headers for ATen headers like c10/cuda/CUDAStream.h
local CUDA_HOME = os.getenv("CUDA_HOME") or os.getenv("CUDA_ROOT") or os.getenv("CUDA_PATH")
local COREX_HOME = os.getenv("COREX_HOME") or "/usr/local/corex"

if CUDA_HOME and os.isdir(path.join(CUDA_HOME, "include")) then
target:add("includedirs", path.join(CUDA_HOME, "include"), { public = true })
end

if COREX_HOME and os.isdir(path.join(COREX_HOME, "include")) then
target:add("includedirs", path.join(COREX_HOME, "include"), { public = true })
end

if COREX_HOME and os.isdir(path.join(COREX_HOME, "targets/x86_64-linux/include")) then
target:add("includedirs", path.join(COREX_HOME, "targets/x86_64-linux/include"), { public = true })
end
end

end)
Expand Down Expand Up @@ -572,6 +605,17 @@ target("_infinicore")
add_defines("BOOST_STACKTRACE_USE_NOOP")
end

if has_config("aten") then
before_build(function (target)
local TORCH_CXX11_ABI = os.iorunv("python", {"-c", "import torch; print(int(torch._C._GLIBCXX_USE_CXX11_ABI))"}):trim()
local TORCH_ABI_DEFINE = "_GLIBCXX_USE_CXX11_ABI=" .. TORCH_CXX11_ABI

target:add("defines", TORCH_ABI_DEFINE)
target:add("cxflags", "-D" .. TORCH_ABI_DEFINE)
target:add("cxxflags", "-D" .. TORCH_ABI_DEFINE)
end)
end

set_default(false)
add_rules("python.library", {soabi = true})
add_packages("pybind11")
Expand Down
60 changes: 60 additions & 0 deletions xmake/iluvatar.lua
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
local iluvatar_arch = get_config("iluvatar_arch") or "ivcore20"
local FLASH_ATTN_ROOT = get_config("flash-attn")

toolchain("iluvatar.toolchain")
set_toolset("cc" , "clang" )
Expand Down Expand Up @@ -117,3 +118,62 @@ target("infiniccl-iluvatar")
add_files("../src/infiniccl/cuda/*.cu")
end
target_end()

-- Locate the Iluvatar `flash_attn_2_cuda*.so` shared object.
-- Resolution order:
--   1. $FLASH_ATTN_2_CUDA_SO          (explicit override; must be an existing file)
--   2. <flash-attn config root>/flash_attn_2_cuda*.so (first glob match wins)
--   3. $FLASH_ATTN_ILUVATAR_CUDA_SO_CONTAINER (container-provided fallback)
-- Raises a configuration error when no candidate resolves to a file.
local function iluvatar_flash_attn_cuda_so_path()
    local env_path = os.getenv("FLASH_ATTN_2_CUDA_SO")
    if env_path and env_path ~= "" then
        env_path = env_path:trim()
        if os.isfile(env_path) then
            return env_path
        end
        print(string.format(
            "warning: iluvatar+flash-attn: FLASH_ATTN_2_CUDA_SO is not a file: %s, fallback to default path",
            env_path
        ))
    end

    if FLASH_ATTN_ROOT and FLASH_ATTN_ROOT ~= "" then
        local files = os.files(path.join(FLASH_ATTN_ROOT, "flash_attn_2_cuda*.so"))
        if files and #files > 0 then
            return files[1]
        end
    end

    local container_path = os.getenv("FLASH_ATTN_ILUVATAR_CUDA_SO_CONTAINER")
    if container_path and container_path ~= "" then
        -- Trim BEFORE the existence check so a trailing newline/space in the
        -- env var does not make an existing file look missing. (Previously the
        -- raw value was passed to os.isfile but the trimmed one was returned.)
        container_path = container_path:trim()
        if os.isfile(container_path) then
            return container_path
        end
    end

    raise("iluvatar+flash-attn: cannot locate flash_attn_2_cuda; install it in current Python env or export FLASH_ATTN_2_CUDA_SO")
end

-- Phony helper target: when a flash-attn root is configured, inject the
-- Torch / Python include and library paths required to build against the
-- prebuilt Iluvatar flash_attn_2_cuda extension; otherwise just log a notice.
target("flash-attn-iluvatar")
    set_kind("phony")
    set_default(false)

    if FLASH_ATTN_ROOT and FLASH_ATTN_ROOT ~= "" then
        before_build(function (target)
            -- Query the active Python interpreter for a one-line value.
            local run_py = function (snippet)
                return os.iorunv("python", {"-c", snippet}):trim()
            end
            local torch_dir = run_py("import torch, os; print(os.path.dirname(torch.__file__))")
            local py_include = run_py("import sysconfig; print(sysconfig.get_paths()['include'])")
            local py_libdir = run_py("import sysconfig; print(sysconfig.get_config_var('LIBDIR'))")

            target:add("includedirs",
                torch_dir .. "/include",
                torch_dir .. "/include/torch/csrc/api/include",
                py_include,
                {public = false}
            )
            target:add("linkdirs", torch_dir .. "/lib", py_libdir, {public = false})
        end)
    else
        before_build(function (target)
            print("Flash Attention not available, skipping flash-attn-iluvatar integration")
        end)
    end
target_end()
Loading