LLM evaluator setting any repetitionPenalty crashes the program #71
Comments
I tried this on |
OK, I can reproduce it. I won't check it in yet because of the quantization issues in #53 -- it will be included there. If you want to try it locally, here is the change:
diff --git a/Libraries/LLM/Evaluate.swift b/Libraries/LLM/Evaluate.swift
index 6c85558..89212a7 100644
--- a/Libraries/LLM/Evaluate.swift
+++ b/Libraries/LLM/Evaluate.swift
@@ -31,7 +31,7 @@ private func applyRepetitionPenalty(
) -> MLXArray {
if repetitionContext.shape[0] > 0 {
let indices = repetitionContext
- var selectedLogits = take(logits, indices, axis: -1).squeezed(axis: 0)
+ var selectedLogits = logits[0..., indices]
selectedLogits = MLX.where(
selectedLogits .< 0, selectedLogits * penalty, selectedLogits / penalty)
@@ -100,7 +100,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {
if prompt.shape[0] <= parameters.repetitionContextSize {
self.repetitionContext = prompt
} else {
- self.repetitionContext = prompt[-parameters.repetitionContextSize ... -1]
+ self.repetitionContext = prompt[(-parameters.repetitionContextSize)...]
}
} else {
self.repetitionContext = []
@@ -120,9 +120,8 @@ public struct TokenIterator: Sequence, IteratorProtocol {
y = sample(logits: logits, temp: parameters.temperature, topP: parameters.topP)
// append the current token to the context and check repetitionPenalty context see if need to remove the first token
if parameters.repetitionContextSize > 1 {
- repetitionContext = concatenated([repetitionContext, y], axis: 0)
if repetitionContext.shape[0] > parameters.repetitionContextSize {
- repetitionContext = repetitionContext[1...]
+ repetitionContext = repetitionContext[(-parameters.repetitionContextSize)...]
}
}
I just switched it to use the full array indexing and made it conform to the Python code. I don't know whether this is a bug in the MLX core code or a bug in the calling code. The calling code certainly requires some changes, and I don't think it is logically the same. |
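For readers following the diff, here is a minimal sketch of the logic it touches, written on plain Swift arrays with no MLX dependency. The function names `penalizeRepeats` and `slideContext` are hypothetical and do not appear in the library. The sketch assumes the new token is appended before the context is trimmed, as the code comment in the diff describes. It is not the actual `Evaluate.swift` implementation.

```swift
// Plain-Swift illustration only; names are hypothetical, not part of the MLX library.

/// Scales the logits of recently seen tokens: negative logits are multiplied
/// by the penalty and non-negative ones are divided by it, mirroring the
/// MLX.where(...) expression in the diff.
func penalizeRepeats(logits: [Float], context: [Int], penalty: Float) -> [Float] {
    // An empty context means there is nothing to penalize.
    guard !context.isEmpty else { return logits }
    var result = logits
    // Using a Set ensures each token id is penalized once, even if it repeats.
    for token in Set(context) where result.indices.contains(token) {
        let value = result[token]
        result[token] = value < 0 ? value * penalty : value / penalty
    }
    return result
}

/// Appends the newest token and keeps only the last `size` entries,
/// the plain-array analogue of slicing with a negative start index.
func slideContext(_ context: [Int], appending token: Int, size: Int) -> [Int] {
    Array((context + [token]).suffix(size))
}
```

The empty-context guard matters here: the reported assertion complains about a zero-sized dispatch, so skipping the work entirely when there are no indices is the safe path in this sketch.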
#76 should fix this |
I'm adding a repetitionPenalty to the GenerateParameters constructor. Regardless of what value I set (I tried 0.5, 1, 1.2, 1.5), the program crashes immediately when the evaluator runs. I was testing various Qwen1.5 models. The error message I got is:
-[MTLDebugComputeCommandEncoder dispatchThreads:threadsPerThreadgroup:]:1441: failed assertion '(threadsPerGrid.width(0) * threadsPerGrid.y(1) * threadsPerGrid.depth(0))(0) must not be 0.'