alexrs-cohere committed
Commit fb7858e · 1 Parent(s): 7045760

Command A Reasoning

Files changed (5):
  1. README.md +1 -1
  2. app.py +161 -120
  3. pyproject.toml +1 -1
  4. requirements.txt +10 -10
  5. uv.lock +1 -1
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: Command A Vision
+title: Command A Reasoning
 emoji: ⚡
 colorFrom: red
 colorTo: purple
app.py CHANGED
@@ -1,175 +1,216 @@
 import os
-import base64
 from collections.abc import Iterator
 
 import gradio as gr
+from gradio import ChatMessage
 from cohere import ClientV2
+from cohere.core import RequestOptions
 
-model_id = "command-a-vision-07-2025"
+model_id = "command-a-reasoning-08-2025"
 
 # Initialize Cohere client
 api_key = os.getenv("COHERE_API_KEY")
 if not api_key:
     raise ValueError("COHERE_API_KEY environment variable is required")
-client = ClientV2(api_key=api_key, client_name="hf-command-a-vision-07-2025")
-
-IMAGE_FILE_TYPES = (".jpg", ".jpeg", ".png", ".webp")
-
-def count_files_in_new_message(paths: list[str]) -> int:
-    image_count = 0
-    for path in paths:
-        if path.endswith(IMAGE_FILE_TYPES):
-            image_count += 1
-    return image_count
-
-
-def validate_media_constraints(message: dict) -> bool:
-    image_count = count_files_in_new_message(message["files"])
-    if image_count > 10:
-        gr.Warning("Maximum 10 images are supported.")
-        return False
-    return True
-
-
-def encode_image_to_base64(image_path: str) -> str:
-    """Encode an image file to base64 data URL format."""
-    with open(image_path, "rb") as image_file:
-        encoded_string = base64.b64encode(image_file.read()).decode('utf-8')
-    # Determine file extension for MIME type
-    if image_path.lower().endswith('.png'):
-        mime_type = "image/png"
-    elif image_path.lower().endswith('.jpg') or image_path.lower().endswith('.jpeg'):
-        mime_type = "image/jpeg"
-    elif image_path.lower().endswith('.webp'):
-        mime_type = "image/webp"
-    else:
-        mime_type = "image/jpeg"  # default
-    return f"data:{mime_type};base64,{encoded_string}"
-
-
-def generate(message: dict, history: list[dict], max_new_tokens: int = 512) -> Iterator[str]:
-    if not validate_media_constraints(message):
-        yield ""
-        return
 
-    # Build messages for Cohere API
-    messages = []
-
-    # Add conversation history
-    for item in history:
-        if item["role"] == "assistant":
-            messages.append({"role": "assistant", "content": item["content"]})
+client = ClientV2(api_key=api_key, client_name="hf-command-a-reasoning-08-2025")
+
+def format_chat_history(messages: list) -> list:
+    """
+    Formats the chat history into a structure Cohere can understand
+    """
+    formatted_history = []
+    for message in messages:
+        # Handle both ChatMessage objects and regular dictionaries
+        if hasattr(message, "metadata") and message.metadata:
+            # Skip thinking messages (messages with metadata)
+            continue
+
+        # Extract role and content safely
+        if hasattr(message, "role"):
+            role = message.role
+            content = message.content
+        elif isinstance(message, dict):
+            role = message.get("role")
+            content = message.get("content")
         else:
-            content = item["content"]
-            if isinstance(content, str):
-                messages.append({"role": "user", "content": [{"type": "text", "text": content}]})
-            else:
-                filepath = content[0]
-                # For file-only messages, don't include empty text
-                messages.append({
-                    "role": "user",
-                    "content": [
-                        {"type": "image_url", "image_url": {"url": encode_image_to_base64(filepath)}}
-                    ]
-                })
+            continue
+
+        if role and content:
+            # Ensure content is a string to prevent validation issues
+            if content is None:
+                content = ""
+            elif not isinstance(content, str):
+                content = str(content)
+
+            formatted_history.append({
+                "role": role,
+                "content": content
+            })
+    return formatted_history
+
+def generate(message: str, history: list, thinking_budget: int) -> Iterator[list]:
+    # Create a clean working copy of the history (excluding thinking messages)
+    working_history = []
+    for msg in history:
+        # Skip thinking messages (messages with metadata)
+        if hasattr(msg, "metadata") and msg.metadata:
+            continue
+        working_history.append(msg)
+
+    # Format chat history for Cohere API (exclude thinking messages)
+    messages = format_chat_history(working_history)
 
     # Add current message
-    current_content = []
-    if message["text"]:
-        current_content.append({"type": "text", "text": message["text"]})
-
-    for file_path in message["files"]:
-        current_content.append({
-            "type": "image_url",
-            "image_url": {"url": encode_image_to_base64(file_path)}
-        })
-
-    # Only add the message if there's content
-    if current_content:
-        messages.append({"role": "user", "content": current_content})
+    if message:
+        messages.append({"role": "user", "content": message})
 
     try:
+        # Set thinking type based on thinking_budget
+        if thinking_budget == 0:
+            thinking_param = {"type": "disabled"}
+        else:
+            thinking_param = {"type": "enabled", "token_budget": thinking_budget}
         # Call Cohere API using the correct event type and delta access
         response = client.chat_stream(
             model=model_id,
             messages=messages,
             temperature=0.3,
-            max_tokens=max_new_tokens,
+            request_options=RequestOptions(additional_body_parameters={"thinking": thinking_param})
         )
 
-        output = ""
+        # Initialize buffers
+        thought_buffer = ""
+        response_buffer = ""
+        thinking_complete = False
+
+        # Start with just the new assistant messages for this interaction
+        current_interaction = [
+            ChatMessage(
+                role="assistant",
+                content="",
+                metadata={"title": "🧠 Thinking..."}
+            )
+        ]
+
         for event in response:
             if getattr(event, "type", None) == "content-delta":
-                # event.delta.message.content.text is the streamed text
-                text = getattr(event.delta.message.content, "text", "")
-                output += text
-                yield output
+                delta = event.delta
+
+                if hasattr(delta, 'message'):
+                    message = delta.message
+
+                    if hasattr(message, 'content'):
+                        content = message.content
+
+                        # Check for thinking tokens first
+                        thinking_text = getattr(content, 'thinking', None)
+                        if thinking_text:
+                            thought_buffer += thinking_text
+                            # Update thinking message with metadata
+                            current_interaction[0] = ChatMessage(
+                                role="assistant",
+                                content=thought_buffer,
+                                metadata={"title": "🧠 Thinking..."}
+                            )
+                            # Yield only the current interaction, but ensure proper formatting
+                            yield [
+                                {
+                                    "role": msg.role,
+                                    "content": msg.content,
+                                    "metadata": getattr(msg, "metadata", None)
+                                } for msg in current_interaction
+                            ]
+                            continue
+
+                        # Check for regular text tokens
+                        text = getattr(content, 'text', None)
+                        if text:
+                            # Ensure text is a string
+                            if text is None:
+                                text = ""
+                            elif not isinstance(text, str):
+                                text = str(text)
+
+                            # If we haven't completed thinking yet, this might be the start of the response
+                            if not thinking_complete and thought_buffer:
+                                thinking_complete = True
+                                # Add response message below thinking
+                                current_interaction.append(
+                                    ChatMessage(
+                                        role="assistant",
+                                        content=""
+                                    )
+                                )
+
+                            if thinking_complete:
+                                # if thinking is complete, we collapse the thinking message
+                                current_interaction[0] = ChatMessage(
+                                    role="assistant",
+                                    content=thought_buffer,
+                                    metadata={"title": "🧠 Thoughts", "status": "done"}
+                                )
+
+                            response_buffer += text
+                            # Update response message
+                            current_interaction[-1] = ChatMessage(
+                                role="assistant",
+                                content=response_buffer
+                            )
+                            # Yield only the current interaction, but ensure proper formatting
+                            yield [
+                                {
+                                    "role": msg.role,
+                                    "content": msg.content,
+                                    "metadata": getattr(msg, "metadata", None)
+                                } for msg in current_interaction
+                            ]
+
+        # Final cleanup: ensure the final response is clean
+        if thought_buffer and response_buffer:
+            # Keep both thinking and response messages in the final history
+            # The thinking message will be preserved with its metadata
+            pass
 
     except Exception as e:
         gr.Warning(f"Error calling Cohere API: {str(e)}")
-        yield ""
+        yield []
 
 
 examples = [
     [
-        {
-            "text": "Write a COBOL function to reverse a string",
-            "files": [],
-        }
+        "Write a COBOL function to reverse a string"
     ],
     [
-        {
-            "text": "Como sair de um helicóptero que caiu na água?",
-            "files": [],
-        }
+        "Como sair de um helicóptero que caiu na água?"
     ],
     [
-        {
-            "text": "What is the total amount of the invoice with and without tax?",
-            "files": ["assets/invoice-1.jpg"],
-        }
+        "What is the best way to learn machine learning?"
     ],
     [
-        {
-            "text": "¿Contra qué modelo gana más Aya Vision 8B?",
-            "files": ["assets/aya-vision-win-rates.png"],
-        }
+        "Explain quantum computing in simple terms"
     ],
     [
-        {
-            "text": "Erläutern Sie die Ergebnisse in der Tabelle",
-            "files": ["assets/command-a-longbech-v2.png"],
-        }
+        "How do I implement a binary search tree?"
     ],
     [
-        {
-            "text": "Explique la théorie de la relativité en français",
-            "files": [],
-        }
+        "Explique la théorie de la relativité en français"
     ],
-
-
 ]
 
 demo = gr.ChatInterface(
     fn=generate,
     type="messages",
-    textbox=gr.MultimodalTextbox(
-        file_types=list(IMAGE_FILE_TYPES),
-        file_count="multiple",
-        autofocus=True,
-    ),
-    multimodal=True,
-    additional_inputs=[
-        gr.Slider(label="Max New Tokens", minimum=100, maximum=2000, step=10, value=700),
-    ],
-    stop_btn=False,
-    title="Command A Vision",
+    autofocus=True,
+    title="Command A Reasoning",
     examples=examples,
-    run_examples_on_click=False,
-    cache_examples=False,
+    run_examples_on_click=True,
     css_paths="style.css",
     delete_cache=(1800, 1800),
+    cache_examples=False,
+    additional_inputs=[
+        gr.Slider(label="Thinking Budget", minimum=0, maximum=2000, step=10, value=500),
+    ],
 )
 
 if __name__ == "__main__":
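
For quick reference, the new call pattern above reduces, outside of Gradio, to a minimal sketch like the following. It assumes only what the diff itself uses: COHERE_API_KEY in the environment, the "thinking" option passed through the request body via RequestOptions(additional_body_parameters=...), and content-delta events whose delta.message.content carries separate thinking and text fields. The prompt string and the print-based consumer are illustrative.

import os

from cohere import ClientV2
from cohere.core import RequestOptions

client = ClientV2(api_key=os.environ["COHERE_API_KEY"])

response = client.chat_stream(
    model="command-a-reasoning-08-2025",
    messages=[{"role": "user", "content": "Write a COBOL function to reverse a string"}],
    temperature=0.3,
    # "thinking" is not a named SDK argument here; app.py injects it into the
    # request body the same way. {"type": "disabled"} turns reasoning off.
    request_options=RequestOptions(
        additional_body_parameters={"thinking": {"type": "enabled", "token_budget": 500}}
    ),
)

for event in response:
    if getattr(event, "type", None) == "content-delta":
        content = event.delta.message.content
        if getattr(content, "thinking", None):
            print(content.thinking, end="", flush=True)  # streamed reasoning tokens
        if getattr(content, "text", None):
            print(content.text, end="", flush=True)  # streamed answer tokens

Setting the Thinking Budget slider to 0 in the Space maps to {"type": "disabled"}, which is why the UI exposes 0 as a valid slider value.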
pyproject.toml CHANGED
@@ -1,5 +1,5 @@
 [project]
-name = "command-a-vision-07-2025"
+name = "command-a-reasoning-07-2025"
 version = "0.1.0"
 description = ""
 readme = "README.md"
requirements.txt CHANGED
@@ -1,7 +1,7 @@
 # This file was autogenerated by uv via the following command:
 #    uv pip compile pyproject.toml -o requirements.txt
 accelerate==1.8.1
-    # via command-a-vision-07-2025 (pyproject.toml)
+    # via command-a-reasoning-08-2025 (pyproject.toml)
 aiofiles==24.1.0
     # via gradio
 annotated-types==0.7.0
@@ -14,7 +14,7 @@ anyio==4.9.0
 audioread==3.0.1
     # via librosa
 av==14.4.0
-    # via command-a-vision-07-2025 (pyproject.toml)
+    # via command-a-reasoning-08-2025 (pyproject.toml)
 certifi==2025.6.15
     # via
     #   httpcore
@@ -49,7 +49,7 @@ fsspec==2025.5.1
     #   torch
 gradio==5.34.2
     # via
-    #   command-a-vision-07-2025 (pyproject.toml)
+    #   command-a-reasoning-08-2025 (pyproject.toml)
     #   spaces
 gradio-client==1.10.3
     # via gradio
@@ -60,7 +60,7 @@ h11==0.16.0
     #   httpcore
     #   uvicorn
 hf-transfer==0.1.9
-    # via command-a-vision-07-2025 (pyproject.toml)
+    # via command-a-reasoning-08-2025 (pyproject.toml)
 hf-xet==1.1.5
     # via huggingface-hub
 httpcore==1.0.9
@@ -95,7 +95,7 @@ joblib==1.5.1
 lazy-loader==0.4
     # via librosa
 librosa==0.11.0
-    # via command-a-vision-07-2025 (pyproject.toml)
+    # via command-a-reasoning-08-2025 (pyproject.toml)
 llvmlite==0.44.0
     # via numba
 markdown-it-py==3.0.0
@@ -249,7 +249,7 @@ soundfile==0.13.1
 soxr==0.5.0.post1
     # via librosa
 spaces==0.37.1
-    # via command-a-vision-07-2025 (pyproject.toml)
+    # via command-a-reasoning-08-2025 (pyproject.toml)
 starlette==0.46.2
     # via
     #   fastapi
@@ -259,27 +259,27 @@ sympy==1.13.1
 threadpoolctl==3.6.0
     # via scikit-learn
 timm==1.0.16
-    # via command-a-vision-07-2025 (pyproject.toml)
+    # via command-a-reasoning-08-2025 (pyproject.toml)
 tokenizers==0.21.2
     # via transformers
 tomlkit==0.13.3
     # via gradio
 torch==2.5.1
     # via
-    #   command-a-vision-07-2025 (pyproject.toml)
+    #   command-a-reasoning-08-2025 (pyproject.toml)
     #   accelerate
     #   timm
     #   torchvision
 torchvision==0.20.1
     # via
-    #   command-a-vision-07-2025 (pyproject.toml)
+    #   command-a-reasoning-08-2025 (pyproject.toml)
     #   timm
 tqdm==4.67.1
     # via
     #   huggingface-hub
     #   transformers
 transformers==4.53.0
-    # via command-a-vision-07-2025 (pyproject.toml)
+    # via command-a-reasoning-08-2025 (pyproject.toml)
 triton==3.1.0
     # via torch
 typer==0.16.0
uv.lock CHANGED
@@ -321,7 +321,7 @@ wheels = [
 ]
 
 [[package]]
-name = "command-a-vision-07-2025"
+name = "command-a-reasoning-07-2025"
 version = "0.1.0"
 source = { virtual = "." }
 dependencies = [