
speedup ROCm AMD Unified Memory Architecture #7399

Open
Djip007 opened this issue May 19, 2024 · 29 comments
Labels
enhancement New feature or request

Comments


Djip007 commented May 19, 2024

GGML_HIP_UMA allows using hipMallocManaged to get UMA on AMD/HIP GPUs.

I have a Ryzen 7940HS and ran some tests. Using UMA allows using much more memory than the VRAM reserved for the iGPU, which is nice, and it gives some speedup over the CPU. But by default it uses "fine-grained" memory, which is slow. If we can use "coarse-grained" memory instead, we can get more speed.

What I did for the test is replace:

#define cudaMalloc hipMallocManaged
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size)
#else
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#endif

by:

#ifdef GGML_HIP_UMA
template<typename T>
static inline auto gpuAlloc(T** adr, size_t size) {
    auto res = hipMallocManaged(adr, size);
    if (res == hipSuccess) {
        return hipMemAdvise (*adr, size, hipMemAdviseSetCoarseGrain, 0);
    }
    return res;
}
#define cudaMalloc gpuAlloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size)
#else
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#endif

I tested with the llamafile 8.4 source code (sorry, I made some other changes to it as well, like adding an option to enable UMA...)
With "mistral-7b-instruct-v0.2.Q8_0.gguf" on a "Framework 16 with 7940HS / no dGPU + Fedora 40 (ROCm 6)" I get these results:

// - zen4 x 8 => CPU
llama_print_timings:        load time =    1383.08 ms
llama_print_timings:      sample time =       4.02 ms /   534 runs   (    0.01 ms per token, 132835.82 tokens per second)
llama_print_timings: prompt eval time =   29762.54 ms /  1466 tokens (   20.30 ms per token,    49.26 tokens per second)
llama_print_timings:        eval time =   82369.69 ms /   533 runs   (  154.54 ms per token,     6.47 tokens per second)
llama_print_timings:       total time =  112190.80 ms /  1999 tokens

// - gfx1103 / rocblas / HSA_OVERRIDE_GFX_VERSION=11.0.1 
llama_print_timings:        load time =    5391.01 ms
llama_print_timings:      sample time =       2.84 ms /   406 runs   (    0.01 ms per token, 143058.49 tokens per second)
llama_print_timings: prompt eval time =   14886.53 ms /  1466 tokens (   10.15 ms per token,    98.48 tokens per second)
llama_print_timings:        eval time =   67061.92 ms /   405 runs   (  165.58 ms per token,     6.04 tokens per second)
llama_print_timings:       total time =   82020.34 ms /  1871 tokens

// - gfx1103 / rocblas / HSA_OVERRIDE_GFX_VERSION=11.0.1 + "hipMemAdviseSetCoarseGrain"
llama_print_timings:        load time =    5470.93 ms
llama_print_timings:      sample time =       2.94 ms /   121 runs   (    0.02 ms per token, 41212.53 tokens per second)
llama_print_timings: prompt eval time =    8093.41 ms /  1466 tokens (    5.52 ms per token,   181.14 tokens per second)
llama_print_timings:        eval time =   12917.88 ms /   120 runs   (  107.65 ms per token,     9.29 tokens per second)
llama_print_timings:       total time =   21095.91 ms /  1586 tokens

As you can see I get 2x on the GPU with coarse-grained memory, and almost 4x over the CPU, for prompt eval. And even +40% for token generation...

Note: to be fair, I get crashes from time to time on the GPU... but that is another story, and this gfx1103 is not fully supported (I don't know whether the crashes come from the ROCm 6.0 rebuild by Fedora or also exist with other ROCm builds...).

PS: I need to use the 'HSA_OVERRIDE_GFX_VERSION=11.0.1' env var... it is faster than overriding to 11.0.0 or 11.0.2, until gfx1103 is supported directly.

PS2: sorry for my bad English...

Djip007 added the enhancement (New feature or request) label May 19, 2024

Djip007 commented May 19, 2024

I did not make a PR for it because, as you can see:

hipMemAdvise (*adr, size, hipMemAdviseSetCoarseGrain, 0);

the last parameter is the "ID" of the GPU... I don't know how to get the right one when more than one GPU is present, and I don't have such a config to test...

https://rocm.docs.amd.com/projects/HIP/en/latest/doxygen/html/group___memory_m.html#ga5c8a3ea8a8702747588082ed39ea51bf
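For reference, a minimal sketch of one way the device ID could be picked when no explicit ID is at hand, assuming the allocation targets the calling thread's current/default HIP device; this helper (uma_alloc_coarse) is hypothetical, not code from the thread:

#include <hip/hip_runtime.h>

// Hypothetical helper: allocate managed memory, then advise coarse-grained
// access for whatever device is currently selected on this host thread.
static hipError_t uma_alloc_coarse(void ** ptr, size_t size) {
    hipError_t res = hipMallocManaged(ptr, size);
    if (res != hipSuccess) {
        return res;
    }
    int device = 0;
    res = hipGetDevice(&device);  // default device of the calling thread
    if (res != hipSuccess) {
        return res;
    }
    return hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device);
}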


slaren commented May 19, 2024

The device id is available in ggml_backend_cuda_buffer_type_alloc_buffer and ggml_cuda_pool::alloc, which is where the calls to cudaMalloc are made. It may make more sense to replace the calls to cudaMalloc there with a more generic function that takes a device id, such as cudaError_t ggml_cuda_alloc_device(void ** ptr, size_t size, int device).


Djip007 commented May 19, 2024

Thanks @slaren! I had a closer look at its uses... I see 3 of them:

  • device for ggml_cuda_pool::alloc
  • buft_ctx->device for ggml_backend_cuda_buffer_type_alloc_buffer
  • id for ggml_backend_cuda_split_buffer_init_tensor

Is it the "true" device ID, or the ID it has inside "ggml_backend_cuda"?

If it is that simple I can try to make a PR (not now, it is time to go to sleep ;) )


slaren commented May 19, 2024

It is the CUDA device id. I assume it is the same kind of device id that is used in hipMemAdvise, but I don't know that.


Djip007 commented May 20, 2024

https://rocm.docs.amd.com/projects/HIP/en/latest/doxygen/html/group___device.html#ga43c1e7f15925eeb762195ccb5e063eae

"hipSetDevice()" it say:

Many HIP APIs implicitly use the "default device" :

Any device memory subsequently allocated from this host thread (using hipMalloc) will be allocated on device.

So it makes sense that this is the correct ID for hipMemAdvise in this case...

👍


Djip007 commented May 20, 2024

It may look like this:

static inline cudaError_t ggml_cuda_alloc_device(void ** ptr, size_t size, int device) {
#if defined(GGML_USE_HIPBLAS)
#if defined(GGML_HIP_UMA)
    auto res = hipMallocManaged(ptr, size);
    if (res == hipSuccess) {
        return hipMemAdvise (*ptr, size, hipMemAdviseSetCoarseGrain, device);
    }
    return res;
#else
    return hipMalloc(ptr, size);
#endif
#else
    return cudaMalloc(ptr, size);
#endif
}

or is it best practice to use CUDA_CHECK and return void?


slaren commented May 20, 2024

Return the cudaError_t and let the caller handle it; it's not always ok to crash the application on an allocation failure.
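A minimal sketch of what that caller-side handling could look like; this is an illustration only, the actual ggml call sites and their error paths may differ, and the cuda* names map to HIP through ggml's existing defines:

#include <cstdio>
#include <cuda_runtime.h>

// Declaration of the helper proposed above (implemented in ggml-cuda).
cudaError_t ggml_cuda_alloc_device(void ** ptr, size_t size, int device);

// Illustrative caller: keep the process alive on allocation failure and let
// the pool / buffer code fall back instead of aborting.
static void * try_alloc(size_t size, int device) {
    void * ptr = nullptr;
    cudaError_t err = ggml_cuda_alloc_device(&ptr, size, device);
    if (err != cudaSuccess) {
        fprintf(stderr, "alloc of %zu bytes on device %d failed: %s\n",
                size, device, cudaGetErrorString(err));
        return nullptr;  // caller can try a smaller size or another backend
    }
    return ptr;
}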

Djip007 added a commit to Djip007/llama.cpp that referenced this issue May 20, 2024
add use of hipMemAdviseSetCoarseGrain when LLAMA_HIP_UMA is enabled.
- get 2x on prompt eval and 1.5x on token gen with rocm6.0 on ryzen 7940HX iGPU (780M/gfx1103)

Djip007 commented May 20, 2024

OK, PR submitted. Tested on FC40 / ROCm 6.0 on a Ryzen 7940HS.

If anyone can compare on another ROCm / APU (Ryzen 6800HS / 7840U...) / OS...

My dream would be to see what it looks like on an MI300A...

As reported on the merge request, I get:

# build
make -j16 LLAMA_HIPBLAS=1 LLAMA_HIP_UMA=1 AMDGPU_TARGETS=gfx1101

# run
HSA_OVERRIDE_GFX_VERSION=11.0.1 ./main -m ~/LLM/mistral-7b-instruct-v0.2.Q8_0.gguf -ngl 999 --temp 0 -c 2048 -p "[INST] {ask for a summary of a long text...} [/INST]"

# before PR:
llama_print_timings:        load time =    3386,19 ms
llama_print_timings:      sample time =      12,14 ms /   609 runs   (    0,02 ms per token, 50177,14 tokens per second)
llama_print_timings: prompt eval time =   15051,81 ms /  1466 tokens (   10,27 ms per token,    97,40 tokens per second)
llama_print_timings:        eval time =  100420,23 ms /   608 runs   (  165,16 ms per token,     6,05 tokens per second)
llama_print_timings:       total time =  115547,81 ms /  2074 tokens

# after PR:
llama_print_timings:        load time =    2606,81 ms
llama_print_timings:      sample time =      12,19 ms /   609 runs   (    0,02 ms per token, 49946,69 tokens per second)
llama_print_timings: prompt eval time =    8120,08 ms /  1466 tokens (    5,54 ms per token,   180,54 tokens per second)
llama_print_timings:        eval time =   65652,40 ms /   608 runs   (  107,98 ms per token,     9,26 tokens per second)
llama_print_timings:       total time =   73841,90 ms /  2074 tokens


Djip007 commented May 21, 2024

I did some more tests on the same hardware.

But first, a piece of advice: use --no-mmap with LLAMA_HIP_UMA=1; it seems to use memory more sensibly, and we need to load the weights in any case. Should we deactivate mmap (at least by default) when LLAMA_HIP_UMA is enabled at build time?

Some results:

//> Mistral 7b Q4_K_M
// HSA_OVERRIDE_GFX_VERSION=11.0.1  ./main -m mistral-7b-instruct-v0.2.Q4_K_M.gguf -ngl 999 --no-mmap --temp 0 -c 2048 -p "[INST] ...
llama_print_timings:        load time =    4965,74 ms
llama_print_timings:      sample time =      10,01 ms /   387 runs   (    0,03 ms per token, 38653,62 tokens per second)
llama_print_timings: prompt eval time =    7884,96 ms /  1466 tokens (    5,38 ms per token,   185,92 tokens per second)
llama_print_timings:        eval time =   29657,97 ms /   386 runs   (   76,83 ms per token,    13,02 tokens per second)
llama_print_timings:       total time =   37612,63 ms /  1852 tokens
//> Mixtral 8x7b Q4_K_M
// HSA_OVERRIDE_GFX_VERSION=11.0.1  ./main -m mixtral-8x7b-instruct-v0.1.Q4_K_M.gguf -ngl 999 --no-mmap --temp 0 -c 2048 -p "[INST] ...
llama_print_timings:        load time =   12226.79 ms
llama_print_timings:      sample time =       1.03 ms /   138 runs   (    0.01 ms per token, 134502.92 tokens per second)
llama_print_timings: prompt eval time =   16152.11 ms /  1466 tokens (   11.02 ms per token,    90.76 tokens per second)
llama_print_timings:        eval time =   19030.05 ms /   137 runs   (  138.91 ms per token,     7.20 tokens per second)
llama_print_timings:       total time =   35221.27 ms /  1603 tokens
//> Mixtral 8x7b Q6_K
// HSA_OVERRIDE_GFX_VERSION=11.0.1  ./main -m mixtral-8x7b-instruct-v0.1.Q6_K.gguf -ngl 999 --no-mmap --temp 0 -c 2048 -p "[INST] ...
llama_print_timings:        load time =   30726,87 ms
llama_print_timings:      sample time =      27,06 ms /   388 runs   (    0,07 ms per token, 14339,04 tokens per second)
llama_print_timings: prompt eval time =   16200,58 ms /  1466 tokens (   11,05 ms per token,    90,49 tokens per second)
llama_print_timings:        eval time =   58623,74 ms /   387 runs   (  151,48 ms per token,     6,60 tokens per second)
llama_print_timings:       total time =   74941,65 ms /  1853 tokens
//> Meta-Llama-3-8B-Instruct.fp16.gguf
// HSA_OVERRIDE_GFX_VERSION=11.0.1  ./main -m Meta-Llama-3-8B-Instruct.fp16.gguf -ngl 999 --no-mmap --temp 0 -c 2048 -p "[INST] ...
llama_print_timings:        load time =    8016.27 ms
llama_print_timings:      sample time =       3.26 ms /   110 runs   (    0.03 ms per token, 33721.64 tokens per second)
llama_print_timings: prompt eval time =    6542.20 ms /  1328 tokens (    4.93 ms per token,   202.99 tokens per second)
llama_print_timings:        eval time =   22346.55 ms /   109 runs   (  205.01 ms per token,     4.88 tokens per second)
llama_print_timings:       total time =   28941.99 ms /  1437 tokens


slaren commented May 21, 2024

I am not sure what issue you are having with mmap, but we cannot change the behavior of llama.cpp based on the configuration of a backend.


Djip007 commented May 23, 2024

The "issue"... not completely sure. I will try to explain how I see it (I don't know if that's "true")

Imagine: my computer have 64Go of RAM, 4Go is reserve for the GPU (VRAM) in bios. So 60Go of RAM is usable. OS/firefox... use 8Go (an example ;) ) so I have 52Go of RAM for the model.
The model is 48Go ...

When it is load with mmap active, and UMA, each weight is read from mmap and copie in a hipMallocManaged new memory. read from mmap involves loading in RAM the tensor and after copie it in "GPU_UMA". This means the tensor is 2 times in RAM... after 52/2=26Go of load no more memory is available, and the kernel need to find un-needed RAM. look it consider the memory that is the oldest accessed ... mmap "RAM" was used a short time ago... so Firefox RAM is Transfer to swap ... with more load it finally evicts the oldest mmap "RAM"...
With no mmap the weight look like it is directly copied on "GPU_UMA" (this is what I am least sure of...) so all the 48Go of tensor have enough space to be copied.

So don't really know, but mmap+swap+UMA ... puts a lot of pressure on Linux memory management...

For me use mmap with "full" GPU offloading is a bad idee ... (with normal "cudaMalloc" on device, imagine mmap on computer with 8G of RAM and a GPU of 48Go of VRAM (I know, it's not really a good choice ;) ) but if the model fit in VRAM it may be possible to made it work...

That say, I don't know if we simply need to give some advise/good practice ... or make llama.cpp more intelligent to chose "better" strategie ... like for exemple use mmap by default only if the weight will not be copied on "local backend" ... but simple to say... not to "code"

ps: I have many more question with llama.cpp memory management (like: with UMA, tensor are in RAM. Did we have to do something to avoid copie back in RAM if an OP is execute on CPU? ...) but I need to have closer look on that and open an other issue if I can't figure how it work...


slaren commented May 23, 2024

Ok, I understand, thanks for the explanation. For the CPU and Metal backends, we have functions to create a backend buffer from a pointer. In this way, the backend can use the memory-mapped file without additional copies. I think it may be possible to do this with CUDA/HIP for UMA devices using hip/cudaHostRegister. We would need to add a function similar to ggml_backend_cpu_buffer_from_ptr and ggml_backend_metal_buffer_from_ptr for the CUDA backend, and add support in llm_load_tensors.
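A rough sketch of the registration idea; the helper name (register_mapped_region) is hypothetical and the real test implementation may differ, only the CUDA runtime calls are real API:

#include <cuda_runtime.h>

// Hypothetical sketch: make an existing memory-mapped region visible to the
// GPU without copying it, in the spirit of a *_buffer_from_ptr function.
static void * register_mapped_region(void * host_ptr, size_t size) {
    // Pin the mmap'ed pages and map them into the device address space.
    if (cudaHostRegister(host_ptr, size, cudaHostRegisterMapped) != cudaSuccess) {
        return nullptr;
    }
    void * dev_ptr = nullptr;
    if (cudaHostGetDevicePointer(&dev_ptr, host_ptr, 0) != cudaSuccess) {
        cudaHostUnregister(host_ptr);
        return nullptr;
    }
    // On UMA devices host_ptr and dev_ptr usually alias the same memory,
    // so tensors can be used in place with no extra copy.
    return dev_ptr;
}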


slaren commented May 23, 2024

I have made a test implementation in 518b752. It works for me on a 3090, although slowly.


Djip007 commented May 23, 2024

Wonderful... No time to test/bench now...

Adding on HIP the same hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device) somewhere in ggml_backend_cuda_buffer_from_ptr may give the same perf as hipMallocManaged + hipMemAdvise on an AMD APU...

I will try to run some benchmarks with my "gemm_test" tomorrow using this memory configuration (malloc / hipHostRegister / hipHostGetDevicePointer and hipMemAdviseSetCoarseGrain)... on my Ryzen 7940HS!

🤞


Djip007 commented May 23, 2024

OK, I did some benchmarks with rocBLAS... (10000 iterations of a 1024x1024 GEMM, FP16 => FP32) on the Ryzen 7940HS APU.

hipMallocManaged + hipMemAdvise => ~10.6 GFlops
malloc + hipHostRegister + hipHostGetDevicePointer => ~2.8 GFlops
malloc + hipHostRegister + hipHostGetDevicePointer + hipMemAdvise => ~9.8 GFlops

Not bad: less than a 10% loss with "correct" malloc use.

(See https://github.com/Djip007/gemm_test/blob/main/gemm_hipblas_03.cpp#L115 for how I register the RAM.)

The test uses malloc, not mmap... so as a next step I have to take your test implementation, switch it to the HIP API calls, and benchmark it with an LLM.
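For context, a sketch of the third configuration above (malloc + hipHostRegister + hipHostGetDevicePointer + hipMemAdvise); it approximates the linked gemm_test code rather than copying it, and the flag choices and helper name are assumptions:

#include <hip/hip_runtime.h>
#include <cstdlib>

// Approximate sketch of the registered-malloc configuration benchmarked above.
static void * alloc_registered_coarse(size_t bytes, int device) {
    void * host = std::malloc(bytes);
    if (host == nullptr) {
        return nullptr;
    }
    // Pin the pages and map them into the GPU address space.
    if (hipHostRegister(host, bytes, hipHostRegisterMapped) != hipSuccess) {
        std::free(host);
        return nullptr;
    }
    void * dev = nullptr;
    if (hipHostGetDevicePointer(&dev, host, 0) != hipSuccess) {
        return nullptr;
    }
    // The coarse-grain advise is what recovers most of the performance
    // (~2.8 => ~9.8 GFlops in the numbers above).
    (void) hipMemAdvise(dev, bytes, hipMemAdviseSetCoarseGrain, device);
    return dev;
}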


Djip007 commented May 24, 2024

OK, next: I checked out your test and changed ggml_backend_cuda_buffer_from_ptr to use HIP: d38ae30
Here is what I get when benchmarking 3 "configs" on my 7940HS:
1- main/CPU, for reference.
2- the hip_UMA PR: (build: make -j16 LLAMA_HIPBLAS=1 LLAMA_HIP_UMA=1 AMDGPU_TARGETS=gfx1101, run: HSA_OVERRIDE_GFX_VERSION=11.0.1 ./main -ngl 999 --no-mmap --temp 0 -c 2048 ...)
3- the hip_mmap test only ("mmap on UMA"): (build: make -j16 LLAMA_HIPBLAS=1 AMDGPU_TARGETS=gfx1101, run: HSA_OVERRIDE_GFX_VERSION=11.0.1 ./main -ngl 999 --temp 0 -c 2048 ...)

#> mistral Q8 / mistral-7b-instruct-v0.2.Q8_0.gguf
# 1-
llama_print_timings:        load time =    1383.08 ms
llama_print_timings:      sample time =       4.02 ms /   534 runs   (    0.01 ms per token, 132835.82 tokens per second)
llama_print_timings: prompt eval time =   29762.54 ms /  1466 tokens (   20.30 ms per token,    49.26 tokens per second)
llama_print_timings:        eval time =   82369.69 ms /   533 runs   (  154.54 ms per token,     6.47 tokens per second)
llama_print_timings:       total time =  112190.80 ms /  1999 tokens

# 2-
llama_print_timings:        load time =    5470.93 ms
llama_print_timings:      sample time =       2.94 ms /   121 runs   (    0.02 ms per token, 41212.53 tokens per second)
llama_print_timings: prompt eval time =    8093.41 ms /  1466 tokens (    5.52 ms per token,   181.14 tokens per second)
llama_print_timings:        eval time =   12917.88 ms /   120 runs   (  107.65 ms per token,     9.29 tokens per second)
llama_print_timings:       total time =   21095.91 ms /  1586 tokens

# 3-
llama_print_timings:        load time =    4965,05 ms
llama_print_timings:      sample time =       7,17 ms /   367 runs   (    0,02 ms per token, 51214,07 tokens per second)
llama_print_timings: prompt eval time =    7460,36 ms /  1460 tokens (    5,11 ms per token,   195,70 tokens per second)
llama_print_timings:        eval time =   39684,78 ms /   366 runs   (  108,43 ms per token,     9,22 tokens per second)
llama_print_timings:       total time =   47213,29 ms /  1826 tokens
:0:/builddir/build/BUILD/clr-rocm-6.0.2/rocclr/device/device.cpp:321 : 0118755366 us: [pid:4209  tid:0x7f341eaf4100] Memobj map does not have ptr: 0x0
Abandon (core dumped)
#> Mixtral 8x7b Q4_K_M
# -1
llama_print_timings:        load time =     619.85 ms
llama_print_timings:      sample time =       4.04 ms /   534 runs   (    0.01 ms per token, 132014.83 tokens per second)
llama_print_timings: prompt eval time =   61530.58 ms /  1466 tokens (   41.97 ms per token,    23.83 tokens per second)
llama_print_timings:        eval time =   89358.63 ms /   533 runs   (  167.65 ms per token,     5.96 tokens per second)
llama_print_timings:       total time =  150946.38 ms /  1999 tokens

# -2
llama_print_timings:        load time =   12226.79 ms
llama_print_timings:      sample time =       1.03 ms /   138 runs   (    0.01 ms per token, 134502.92 tokens per second)
llama_print_timings: prompt eval time =   16152.11 ms /  1466 tokens (   11.02 ms per token,    90.76 tokens per second)
llama_print_timings:        eval time =   19030.05 ms /   137 runs   (  138.91 ms per token,     7.20 tokens per second)
llama_print_timings:       total time =   35221.27 ms /  1603 tokens

# -3
llama_print_timings:        load time =   15662,88 ms
llama_print_timings:      sample time =      11,54 ms /   534 runs   (    0,02 ms per token, 46293,89 tokens per second)
llama_print_timings: prompt eval time =   15500,37 ms /  1460 tokens (   10,62 ms per token,    94,19 tokens per second)
llama_print_timings:        eval time =   79876,90 ms /   533 runs   (  149,86 ms per token,     6,67 tokens per second)
llama_print_timings:       total time =   95426,05 ms /  1993 tokens
:0:/builddir/build/BUILD/clr-rocm-6.0.2/rocclr/device/device.cpp:321 : 1919801141 us: [pid:11542 tid:0x7f562acb2100] Memobj map does not have ptr: 0x0
Abandon (core dumped)
#> Mixtral 8x7b Q6_K
# -1
llama_print_timings:        load time =   42931,53 ms
llama_print_timings:      sample time =       8,40 ms /   422 runs   (    0,02 ms per token, 50262,03 tokens per second)
llama_print_timings: prompt eval time =   67656,85 ms /  1466 tokens (   46,15 ms per token,    21,67 tokens per second)
llama_print_timings:        eval time =   83567,87 ms /   421 runs   (  198,50 ms per token,     5,04 tokens per second)
llama_print_timings:       total time =  151276,59 ms /  1887 tokens

# -2
llama_print_timings:        load time =   30726,87 ms
llama_print_timings:      sample time =      27,06 ms /   388 runs   (    0,07 ms per token, 14339,04 tokens per second)
llama_print_timings: prompt eval time =   16200,58 ms /  1466 tokens (   11,05 ms per token,    90,49 tokens per second)
llama_print_timings:        eval time =   58623,74 ms /   387 runs   (  151,48 ms per token,     6,60 tokens per second)
llama_print_timings:       total time =   74941,65 ms /  1853 tokens

# -3
llama_print_timings:        load time =   20793,13 ms
llama_print_timings:      sample time =      34,36 ms /   548 runs   (    0,06 ms per token, 15949,24 tokens per second)
llama_print_timings: prompt eval time =   15704,29 ms /  1460 tokens (   10,76 ms per token,    92,97 tokens per second)
llama_print_timings:        eval time =   89642,54 ms /   547 runs   (  163,88 ms per token,     6,10 tokens per second)
llama_print_timings:       total time =  105467,46 ms /  2007 tokens
:0:/builddir/build/BUILD/clr-rocm-6.0.2/rocclr/device/device.cpp:321 : 2503846519 us: [pid:12712 tid:0x7fcf18665100] Memobj map does not have ptr: 0x0
Abandon (core dumped)


Djip007 commented May 24, 2024

some more "elements/comment":

  • I use hipMemAdvise but not really speedup with it. may be how mmap is config make hip already "cachable"
  • look like some quantized model do not support mmap ... in this cas It try tu use hipMalloc and failed with "out of memory" (I can't define more than 4Go bios...)
  • As you can see there is a crache after prompte timing (and before exit...) but don't matter for benchmark.
  • fast init with mmap now in all cas .

I have 64Go of RAM... so if you want I test some other model.


slaren commented May 24, 2024

The results look good, it seems that it doesn't decrease performance, which is the main goal. The crash at the end is expected, it's because the memory is not released with the correct function.

The main advantage that I would expect from this would be improved load time, especially when loading the model repeatedly, since it would be in the system cache, but I am not sure if the difference is big enough to put a lot of effort into this. There may be some issues with this implementation: it is a bit hacky, and it removes the padding that is usually added to quantized tensors, which may cause issues with the MMQ kernels.

I don't have any devices with UMA, so I cannot really finish this, but if this is something that you find interesting feel free to use the code in any way you want, and open a PR if you think this may be useful.

look like some quantized model do not support mmap

Old mixtral models do not support mmap, they need to be converted to gguf again.


Djip007 commented May 24, 2024

I don't have any devices with UMA, so I cannot really finish this, but if this is something that you find interesting feel free to use the code in any way you want, and open a PR if you think this may be useful.

I must go to sleep... But I think this would make people dream.

To release something I may need some help figuring out where I need to "release" the memory (I think I know what to do, but not where...).

The biggest question will be when/how to activate this feature. But yes, I may open a PR for this.
Thanks for your help!!!

The main advantage that I would expect from this would be improved load time, especially when loading the model repeatedly since it would be in the system cache [...]

Not sure how to measure that... but yes, if it works it can be useful in certain scenarios.


Djip007 commented May 24, 2024

#> Meta-Llama-3-8B-Instruct.fp16.gguf
// - gfx1103 / rocblas / HSA_OVERRIDE_GFX_VERSION=11.0.1 + mmap
# 1 load
llama_print_timings:        load time =    9027,84 ms
llama_print_timings:      sample time =      49,96 ms /   623 runs   (    0,08 ms per token, 12470,97 tokens per second)
llama_print_timings: prompt eval time =    5926,34 ms /  1348 tokens (    4,40 ms per token,   227,46 tokens per second)
llama_print_timings:        eval time =  138236,22 ms /   622 runs   (  222,24 ms per token,     4,50 tokens per second)
llama_print_timings:       total time =  144354,97 ms /  1970 tokens

# next load
llama_print_timings:        load time =    2652,25 ms
llama_print_timings:      sample time =      49,33 ms /   623 runs   (    0,08 ms per token, 12630,51 tokens per second)
llama_print_timings: prompt eval time =    5906,83 ms /  1348 tokens (    4,38 ms per token,   228,21 tokens per second)
llama_print_timings:        eval time =  135707,96 ms /   622 runs   (  218,18 ms per token,     4,58 tokens per second)
llama_print_timings:       total time =  141791,26 ms /  1970 tokens

... the load time is already reported in the timings...

9 s => 2.6 s ... yes, there is some gain 😉


Djip007 commented May 26, 2024

I was thinking of doing a small update:

static inline cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
#if defined(GGML_USE_HIPBLAS)
#if defined(GGML_HIP_UMA)
    // try normal alloc
    auto res = hipMalloc(ptr, size);
    if (res == hipErrorOutOfMemory) {
        // Not enough space on VRAM => try on UMA
        GGML_CUDA_LOG_INFO("  Device %d: can not alloc %d MB on VRAM try alloc on GTT\n", device, (uint32_t)(size / 1024 / 1024));
        res = hipMallocManaged(ptr, size);
        if (res == hipSuccess) {
            // Config the memory for best speed (It's not supposed to fail)
            CUDA_CHECK(hipMemAdvise(*ptr, size, hipMemAdviseSetCoarseGrain, device));
        }
    }
    return res;
#else
    return hipMalloc(ptr, size);
#endif
#else
    return cudaMalloc(ptr, size);
#endif
}

Allocate device VRAM if possible, and only use UMA (hipMallocManaged) when that fails. This way the option can be enabled without impacting perf on GPUs that have enough VRAM.

What do you think @slaren?


slaren commented May 26, 2024

The problem is that unless you have a system where the GPU and the CPU share the same memory, using shared system memory in the GPU is too slow to be useful. It is generally better to use the CPU and offload fewer layers. I also think that using managed memory is completely unnecessary and I have no idea who is using this, so I can only assume that this is just another weirdness of AMD GPUs.


Djip007 commented May 26, 2024

Shared memory is used on all APUs: Apple / Intel / Qualcomm / AMD, and even NVIDIA on Jetson. Managed memory also makes it possible to share VRAM between GPUs that have a fast interconnect (Infinity Fabric in AMD's case). So for example on an MI300A it may be possible to share all of the 4x128 GB of RAM across all cores + GPUs. But that is another story.

Yes, on a dGPU it is better to offload fewer layers; that is why I changed it to allocate on VRAM by default. This way, if the user chooses the right number of layers to offload, enabling this option shouldn't change anything.
The idea is not to allow the use of managed memory on dGPUs, but to be able to activate this function without impacting those GPUs when they are "well" configured.

By the way, this change is only for AMD GPUs. Do you know what happens on an NVIDIA Jetson: does cudaMalloc use all the RAM / GTT?

And yes, I really need to find time for "mmap"; it may be better than this UMA approach for low-VRAM APUs... but harder to implement.


slaren commented May 26, 2024

I just find this very confusing, because the point of managed memory is that it can be accessed both from the CPU and the GPU automatically. The consequence of this on an iGPU like the Jetson is that this memory will be uncached to ensure coherency, so it is still desirable to allocate memory exclusively for the iGPU using cudaMalloc to allow cached access. The memory that the ggml CUDA backend allocates with cudaMalloc is never accessed from the CPU, so it seems completely pointless to allocate it as managed. But I don't know the quirks of AMD GPUs, so I am probably missing something.


Djip007 commented May 26, 2024

What I understand:

Until the next kernel (6.10), hipMalloc will only use the reserved VRAM on consumer APUs. And on most laptops it is not possible to reserve a lot of RAM in the BIOS. For me, the point of using hipMallocManaged here is only to get access to more RAM. On an APU it is possible (as you have seen) to use other functions for the same thing; I get the same speed with hipHostMalloc, and even with hipHostRegister with the correct config on my APU. I don't think that in our case / on this GPU it makes a difference to use hipMallocManaged (it may on MI250 / MI300... but I can't test 😥).

I still have tests to do (and to come), like: is hipMallocManaged limited by the GTT memory size?

For me, we could have written it like this with the same result on an APU (but maybe not on a high-end MI250 GPU...):

static inline cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
#if defined(GGML_USE_HIPBLAS)
#if defined(GGML_HIP_UMA)
    return hipHostMalloc(ptr, size, hipHostMallocNonCoherent) + hipHostGetDevicePointer(...);
#else
    return hipMalloc(ptr, size);
#endif
#else
    return cudaMalloc(ptr, size);
#endif
}
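Spelled out, that shorthand might look roughly like the helper below (untested; the flag combination and the pointer swap are assumptions):

#include <hip/hip_runtime.h>

// Untested expansion of the shorthand above: pinned, non-coherent host memory
// plus its device-side view.
static hipError_t uma_host_alloc(void ** ptr, size_t size) {
    hipError_t res = hipHostMalloc(ptr, size, hipHostMallocNonCoherent | hipHostMallocMapped);
    if (res != hipSuccess) {
        return res;
    }
    void * dev_ptr = nullptr;
    res = hipHostGetDevicePointer(&dev_ptr, *ptr, 0);
    if (res == hipSuccess) {
        *ptr = dev_ptr;  // on an APU this usually aliases the host pointer
    }
    return res;
}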

PS: Do you know how cudaMalloc works on Jetson? Can it use all the RAM, or is it limited to the kernel GTT size, or even to VRAM? Is there no equivalent to hipMemAdvise in CUDA?


slaren commented May 26, 2024

There is some documentation about the Jetson memory architecture here: https://docs.nvidia.com/cuda/cuda-for-tegra-appnote/index.html

My understanding is that the Jetson is connected directly to the system memory and there is no GTT, it can access all the RAM directly, as long as it is pinned by the OS. But I am not an expert about this.


Djip007 commented May 26, 2024

https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__MEMORY.html#group__CUDART__MEMORY_1ge37112fc1ac88d0f6bab7a945e48760a

Hmm... yes, really not clear...

I will have a look at your other post.

What is not clear to me (on AMD GPUs) is whether pinned memory is limited by the GTT or not.


Djip007 commented May 26, 2024

There is some documentation about the Jetson memory architecture here: https://docs.nvidia.com/cuda/cuda-for-tegra-appnote/index.html

Wow... I need more time and tests to "really" understand all that... It looks so close to the hardware that I can imagine AMD may have slight differences...

I may need to use

int managed_memory = 0;
HIPCHECK(hipDeviceGetAttribute(&managed_memory,
  hipDeviceAttributeManagedMemory,p_gpuDevice));

to check whether I can use it... or figure out when it can fail, to know whether it matters...
It looks the same on NVIDIA.


Djip007 commented May 27, 2024

I made some tests...

hipMallocManaged on AMD is not limited by the GTT size. So even after kernel 6.10 and the use of GTT for hipMalloc (if I understand correctly), it can be interesting to use it as a last resort on AMD APUs.
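A rough way to probe that kind of limit (purely illustrative; the 1 GiB step and the stop condition are arbitrary, and the allocations are intentionally leaked since the process exits right after):

#include <hip/hip_runtime.h>
#include <cstdio>

// Illustrative probe: compare the memory HIP reports for the device with how
// far hipMallocManaged will actually go.
int main() {
    size_t free_b = 0, total_b = 0;
    (void) hipMemGetInfo(&free_b, &total_b);
    std::printf("device reports %zu MiB total\n", total_b / (1024 * 1024));

    const size_t step = size_t(1) << 30;  // 1 GiB per allocation
    size_t allocated = 0;
    void * p = nullptr;
    while (hipMallocManaged(&p, step) == hipSuccess && allocated < 8 * total_b) {
        allocated += step;
    }
    std::printf("managed allocations succeeded up to ~%zu MiB\n",
                allocated / (1024 * 1024));
    return 0;
}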

slaren added a commit that referenced this issue May 27, 2024
* update HIP_UMA #7399

add use of hipMemAdviseSetCoarseGrain when LLAMA_HIP_UMA is enabled.
- get 2x on prompt eval and 1.5x on token gen with rocm6.0 on ryzen 7940HX iGPU (780M/gfx1103)

* simplify code, more consistent style

---------

Co-authored-by: slaren <slarengh@gmail.com>