Fix the chat function to support the Gradio demo
modeling_mplugowl3.py  (+20 -31)
@@ -142,7 +142,6 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
         media_offset=None,
         attention_mask=None,
         tokenizer=None,
-        return_vision_hidden_states=False,
         stream=False,
         decode_text=False,
         **kwargs
@@ -156,9 +155,6 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
             result = self._decode_stream(input_ids=input_ids, image_embeds=image_embeds, media_offset=media_offset, tokenizer=tokenizer, **kwargs)
         else:
             result = self._decode(input_ids=input_ids, image_embeds=image_embeds, media_offset=media_offset, tokenizer=tokenizer, attention_mask=attention_mask, decode_text=decode_text, **kwargs)
-
-        if return_vision_hidden_states:
-            return result, image_embeds
 
         return result
 
@@ -166,10 +162,9 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
         self,
         images,
         videos,
-
+        messages,
         tokenizer,
         processor=None,
-        vision_hidden_states=None,
         max_new_tokens=2048,
         min_new_tokens=0,
         sampling=True,
@@ -180,21 +175,23 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
         use_image_id=None,
         **kwargs
     ):
-        print(
+        print(messages)
+        if len(images)>1:
+            cut_flag=False
+        else:
+            cut_flag=True
         if processor is None:
             if self.processor is None:
-
-
-
-
-        inputs
-
-
-
-
-
-            max_length=max_inp_length
-        ).to(self.device)
+                processor = self.init_processor(tokenizer)
+            else:
+                processor = self.processor
+        inputs = processor(messages, images=images, videos=videos, cut_enable=cut_flag)
+        inputs.to('cuda')
+        inputs.update({
+            'tokenizer': tokenizer,
+            'max_new_tokens': max_new_tokens,
+            # 'stream':True,
+        })
 
         if sampling:
             generation_config = {
@@ -202,12 +199,12 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
                 "top_k": 100,
                 "temperature": 0.7,
                 "do_sample": True,
-                "repetition_penalty": 1.05
+                # "repetition_penalty": 1.05
             }
         else:
             generation_config = {
                 "num_beams": 3,
-                "repetition_penalty": 1.2,
+                # "repetition_penalty": 1.2,
             }
 
         if min_new_tokens > 0:
@@ -216,14 +213,10 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
         generation_config.update(
             (k, kwargs[k]) for k in generation_config.keys() & kwargs.keys()
         )
-
-        inputs.pop("image_sizes")
+        print(inputs)
         with torch.inference_mode():
             res = self.generate(
                 **inputs,
-                tokenizer=tokenizer,
-                max_new_tokens=max_new_tokens,
-                vision_hidden_states=vision_hidden_states,
                 stream=stream,
                 decode_text=True,
                 **generation_config
@@ -238,9 +231,5 @@ class mPLUGOwl3Model(mPLUGOwl3PreTrainedModel):
             return stream_gen()
 
         else:
-
-                answer = res
-            else:
-                answer = res[0]
+            answer = res[0]
             return answer
-