aeb56 committed
Commit b705945 · Parent: ef25cbe

Workaround flash-attn: create fake module with PyTorch fallback attention

Files changed (2)
  1. app.py +55 -6
  2. requirements.txt +0 -3
app.py CHANGED
@@ -215,15 +215,63 @@ class ChatBot:
         logs += f"⏱️ Estimated time: 30-60 minutes\n\n"
         yield status_table, logs
 
-        # Run lm_eval with optimized memory settings
-        # Note: We use parallelize=True to distribute across GPUs instead of device_map in model_args
-        # We need to install flash-attn for this model to work properly
+        # Create a fake flash_attn package to avoid import errors
+        # This will fallback to standard PyTorch attention
+        fake_flash_dir = f"/tmp/flash_attn_{timestamp}"
+        os.makedirs(fake_flash_dir, exist_ok=True)
+
+        with open(os.path.join(fake_flash_dir, "__init__.py"), 'w') as f:
+            f.write("""
+# Fake flash_attn module that falls back to standard PyTorch attention
+import torch
+
+def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, **kwargs):
+    '''Fallback to standard PyTorch attention (slower but works without flash-attn)'''
+    if softmax_scale is None:
+        softmax_scale = 1.0 / (q.size(-1) ** 0.5)
+
+    # Standard attention: softmax(Q @ K.T) @ V
+    attn_weights = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
+
+    if causal:
+        seq_len = attn_weights.size(-1)
+        causal_mask = torch.triu(torch.ones(seq_len, seq_len, device=attn_weights.device), diagonal=1).bool()
+        attn_weights = attn_weights.masked_fill(causal_mask, float('-inf'))
+
+    attn_weights = torch.softmax(attn_weights, dim=-1)
+
+    if dropout_p > 0:
+        attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout_p)
+
+    output = torch.matmul(attn_weights, v)
+    return output, None # Return None for attention weights
+
+def flash_attn_varlen_func(*args, **kwargs):
+    return flash_attn_func(*args, **kwargs)
+
+__version__ = "2.5.0"
+""")
+
+        # Add fake package to Python path for subprocess
+        import sys
+        if f"/tmp" not in sys.path:
+            sys.path.insert(0, "/tmp")
+
+        # Set PYTHONPATH environment variable so subprocess can find fake flash_attn
+        env = os.environ.copy()
+        pythonpath = env.get('PYTHONPATH', '')
+        env['PYTHONPATH'] = f"/tmp:{pythonpath}" if pythonpath else "/tmp"
+
+        logs += "⚠️ **Note:** Using fallback PyTorch attention (slower than flash-attn)\n\n"
+        yield status_table, logs
+
+        # Run lm_eval
         cmd = [
             "lm_eval",
             "--model", "hf",
             "--model_args", f"pretrained={MODEL_NAME},trust_remote_code=True,dtype=bfloat16,low_cpu_mem_usage=True,parallelize=True",
             "--tasks", task_string,
-            "--batch_size", "1", # Reduced to minimize memory usage
+            "--batch_size", "1",
             "--output_path", output_dir,
             "--log_samples"
         ]
@@ -233,13 +281,14 @@ class ChatBot:
         logs += "---\n\n### 📜 Live Logs (last 15 lines):\n\n```\n"
         yield status_table, logs
 
-        # Run evaluation
+        # Run evaluation with custom environment
         process = subprocess.Popen(
             cmd,
             stdout=subprocess.PIPE,
             stderr=subprocess.STDOUT,
             text=True,
-            bufsize=1
+            bufsize=1,
+            env=env # Pass custom environment with PYTHONPATH
         )
 
         output_lines = []
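
Note on the fallback: the stubbed flash_attn_func above is mathematically plain scaled dot-product attention, so it can be sanity-checked against torch.nn.functional.scaled_dot_product_attention. The sketch below is illustrative and not part of the commit; the helper name fallback_attn and the (batch, heads, seq_len, head_dim) tensor shapes are assumptions chosen to match what the plain-matmul fallback actually computes (the real flash-attn API expects (batch, seq_len, num_heads, head_dim) inputs and does not reorder them).

import torch
import torch.nn.functional as F

def fallback_attn(q, k, v, causal=False):
    # Same math as the stubbed flash_attn_func: softmax(Q @ K^T * scale) @ V
    scale = 1.0 / (q.size(-1) ** 0.5)
    attn = torch.matmul(q, k.transpose(-2, -1)) * scale
    if causal:
        n = attn.size(-1)
        mask = torch.triu(torch.ones(n, n, device=attn.device), diagonal=1).bool()
        attn = attn.masked_fill(mask, float("-inf"))
    return torch.matmul(torch.softmax(attn, dim=-1), v)

# (batch, heads, seq_len, head_dim) layout, assumed for illustration
q, k, v = (torch.randn(1, 4, 16, 32) for _ in range(3))
out = fallback_attn(q, k, v, causal=True)
ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(out.shape, torch.allclose(out, ref, atol=1e-5))  # expected: torch.Size([1, 4, 16, 32]) True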
requirements.txt CHANGED
@@ -10,9 +10,6 @@ triton>=3.0.0
 # Flash Linear Attention (required by Kimi model)
 git+https://github.com/sustcsonglin/flash-linear-attention.git@main
 
-# Flash Attention (required for attention layers)
-flash-attn>=2.5.0
-
 # Evaluation
 lm-eval>=0.4.0
 
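
With flash-attn dropped from requirements.txt, the lm_eval subprocess only sees the stub through the PYTHONPATH entry that app.py adds. Below is a minimal, illustrative check of that propagation, not the commit's code: the /tmp/flash_attn path and the one-line stub are assumptions for the sketch (a bare `import flash_attn` resolves only when the package directory itself is named flash_attn).

import os
import subprocess
import sys

stub_dir = "/tmp/flash_attn"  # assumption: directory name matches the imported module name
os.makedirs(stub_dir, exist_ok=True)
with open(os.path.join(stub_dir, "__init__.py"), "w") as f:
    f.write('__version__ = "2.5.0"\n')  # one-line stand-in for the full stub written by app.py

env = os.environ.copy()
existing = env.get("PYTHONPATH", "")
env["PYTHONPATH"] = f"/tmp:{existing}" if existing else "/tmp"

# The child process should import the stub instead of failing on the missing flash-attn package
result = subprocess.run(
    [sys.executable, "-c", "import flash_attn; print(flash_attn.__version__)"],
    env=env, capture_output=True, text=True,
)
print(result.stdout.strip())  # expected output: 2.5.0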