Cache (i.e. past_key_values) handling seems very problematic
In modeling_ernie4_5_vl.py, lines 2882-2883 set past_key_values as:
if past_key_values is None:
past_key_values = tuple([None] * len(self.layers))
But lines 2911-2913 check it with:
past_key_value = (
past_key_values[idx] if past_key_values is not None else None # shouldn't this check whether past_key_values[idx] is None instead?
)
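To make the mismatch concrete, here is a minimal standalone sketch (with a hypothetical num_layers; not the real model code) showing that after the initialization at lines 2882-2883 the tuple itself is never None, so the guard in lines 2911-2913 always passes even though every per-layer entry is still None:

```python
num_layers = 4  # hypothetical layer count

# what lines 2882-2883 do on the first forward pass
past_key_values = None
if past_key_values is None:
    past_key_values = tuple([None] * num_layers)

# the guard at lines 2911-2913 checks the wrong object:
assert past_key_values is not None   # always True after the init above
assert past_key_values[0] is None    # but each per-layer entry is still None
# hence past_key_values[0][0].shape raises AttributeError, as in the traceback below
```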
And then running the model with use_cache enabled (use_cache=False works without a problem) gives:
Traceback (most recent call last):
File "/home/ml-team/ml-proj/docs/ernie_reproduce_error.py", line 63, in <module>
generated_ids = model.generate(
^^^^^^^^^^^^^^^
File "/home/ml-team/.pyenv/versions/ml-proj/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "/home/ml-team/.pyenv/versions/ml-proj/lib/python3.12/site-packages/transformers/generation/utils.py", line 2629, in generate
result = self._sample(
^^^^^^^^^^^^^
File "/home/ml-team/.pyenv/versions/ml-proj/lib/python3.12/site-packages/transformers/generation/utils.py", line 3610, in _sample
outputs = self(**model_inputs, return_dict=True)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ml-team/.pyenv/versions/ml-proj/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ml-team/.pyenv/versions/ml-proj/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ml-team/.cache/huggingface/modules/transformers_modules/ERNIE-4.5-VL-28B-A3B-PT/modeling_ernie4_5_vl.py", line 4169, in forward
outputs = self.model(
^^^^^^^^^^^
File "/home/ml-team/.pyenv/versions/ml-proj/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1751, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ml-team/.pyenv/versions/ml-proj/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1762, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/ml-team/.cache/huggingface/modules/transformers_modules/ERNIE-4.5-VL-28B-A3B-PT/modeling_ernie4_5_vl.py", line 2888, in forward
cache_length = past_key_values[0][0].shape[1]
^^^^^^^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'shape'
Was this code tested?
OK, I think the following lines (2885-2889) are redundant:
seq_length_with_past = seq_length
cache_length = 0
if past_key_values[0] is not None:
cache_length = past_key_values[0][0].shape[1]
seq_length_with_past += cache_length
More importantly, lines 2911-2913 should be changed in the following way (bug fix):
# REMOVE START
past_key_value = (
past_key_values[idx] if past_key_values is not None else None
)
# REMOVE END
# REPLACE START
past_key_value = past_key_values[idx]
# REPLACE END
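To show why the replacement is safe, here is a small sketch (hypothetical num_layers, outside the real model) demonstrating that after the initialization at lines 2882-2883 the tuple-level None check is dead code, so indexing directly is equivalent:

```python
num_layers = 4  # hypothetical layer count

def pick(past_key_values, idx):
    # original lines 2911-2913: guards an object that can no longer be None here
    return past_key_values[idx] if past_key_values is not None else None

past_key_values = tuple([None] * num_layers)  # what lines 2882-2883 produce
for idx in range(num_layers):
    # the guarded and unguarded lookups always agree
    assert pick(past_key_values, idx) is past_key_values[idx]
```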
I'm hitting the same issue with this model.
Does the model work with the above change?
Sorry, I forgot that I also edited line 2882 as follows:
if past_key_values is None or past_key_values[0][0] is None: # It was if past_key_values is None
So there is indeed some problematic cache handling, since the past_key_values[0][0] is None check is actually necessary. For now, the above changes make the model runnable; I'll come back to this if it turns out to cause significant performance issues (with vLLM and the like).
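Putting the two edits together, here is a minimal standalone sketch of the amended control flow. Everything here is an assumption for illustration: the layer count, the SimpleNamespace stand-in for a cached key tensor, and the (None, None) placeholder pairs (used so that the [0][0] guard is indexable); it is not the real model code.

```python
from types import SimpleNamespace

NUM_LAYERS = 4  # hypothetical layer count

def kv(cache_len):
    # stand-in for a (key, value) pair of cached tensors of shape (batch, cache_len)
    t = SimpleNamespace(shape=(1, cache_len))
    return (t, t)

def cache_setup(past_key_values, seq_length):
    # edited guard from line 2882 above: an empty first entry also triggers init
    if past_key_values is None or past_key_values[0][0] is None:
        past_key_values = tuple([(None, None)] * NUM_LAYERS)  # placeholder pairs (assumption)

    # cache-length bookkeeping, as in lines 2885-2889
    seq_length_with_past = seq_length
    cache_length = 0
    if past_key_values[0][0] is not None:
        cache_length = past_key_values[0][0].shape[1]
        seq_length_with_past += cache_length

    # bug fix from above: index directly instead of guarding the whole tuple
    per_layer = [past_key_values[idx] for idx in range(NUM_LAYERS)]
    return per_layer, cache_length, seq_length_with_past

# first pass: no cache yet
_, c0, s0 = cache_setup(None, seq_length=7)                         # c0 == 0, s0 == 7
# later pass: 5 cached positions per layer
_, c1, s1 = cache_setup(tuple([kv(5)] * NUM_LAYERS), seq_length=1)  # c1 == 5, s1 == 6
```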