make789 committed on
Commit
8d0b826
·
verified ·
1 Parent(s): 7f46686

Upload ocr_service.py

Browse files
Files changed (1) hide show
  1. ocr_service.py +38 -54
ocr_service.py CHANGED
@@ -133,7 +133,9 @@ def _download_and_patch_model_locally(model_id: str, revision: str) -> str:
133
  Download DeepSeek-OCR to a local dir, patch for CPU:
134
  - remove hardcoded .cuda()
135
  - force float32 (strip .bfloat16() / .to(torch.bfloat16))
136
- - disable torch.autocast("cuda", ...) blocks
 
 
137
 
138
  Return local path for from_pretrained(...).
139
 
@@ -153,71 +155,53 @@ def _download_and_patch_model_locally(model_id: str, revision: str) -> str:
153
  print(f" ✅ Downloaded to: {local_dir}")
154
  local_dir = Path(local_dir)
155
 
156
- def ensure_imports(src: str) -> str:
157
- """Ensure torch and contextlib are imported"""
158
- if "import torch" not in src and "from torch" not in src:
159
- src = "import torch\n" + src
160
- if "import contextlib" not in src:
161
- # Add after torch import
162
- if "import torch" in src:
163
- src = src.replace("import torch", "import torch\nimport contextlib")
164
- elif "from torch" in src:
165
- src = src.replace("from torch", "import contextlib\nfrom torch")
166
- else:
167
- # Add at top if no torch import
168
- lines = src.split('\n')
169
- for i, line in enumerate(lines):
170
- if line.startswith("import ") or line.startswith("from "):
171
- continue
172
- else:
173
- lines.insert(i, "import contextlib")
174
- break
175
- src = '\n'.join(lines)
176
- return src
177
-
178
- def patch_text(src: str) -> str:
179
- """Patch text to remove CUDA/BF16 and force FP32"""
180
- # A) Remove hardcoded CUDA device moves
181
- src = src.replace(".unsqueeze(-1).cuda()", ".unsqueeze(-1)")
182
- src = src.replace("input_ids.unsqueeze(0).cuda()", "input_ids.unsqueeze(0)")
183
- src = src.replace("(images_crop.cuda(), images_ori.cuda())", "(images_crop, images_ori)")
184
- src = src.replace("images_seq_mask = images_seq_mask.unsqueeze(0).cuda()",
185
  "images_seq_mask = images_seq_mask.unsqueeze(0)")
186
- src = src.replace("input_ids.unsqueeze(0).cuda().shape[1]",
187
  "input_ids.unsqueeze(0).shape[1]")
188
 
189
- # B) Force FP32 (strip BF16 casts)
190
- src = re.sub(r"\.bfloat16\(\)", ".float()", src)
191
- src = re.sub(r"\.to\(\s*torch\.bfloat16\s*\)", ".to(torch.float32)", src)
192
- src = re.sub(r"\.to\(\s*dtype\s*=\s*torch\.bfloat16\s*\)", ".to(dtype=torch.float32)", src)
193
 
194
- # C) Disable CUDA autocast blocks
195
- src = ensure_imports(src)
196
- # Match both torch.autocast("cuda", ...) and torch.autocast(device_type="cuda", ...)
197
- src = re.sub(r'torch\.autocast\(\s*["\']cuda["\']\s*,[^)]*\)',
198
- "contextlib.nullcontext()", src)
199
- src = re.sub(r'torch\.autocast\(\s*device_type\s*=\s*["\']cuda["\'][^)]*\)',
200
- "contextlib.nullcontext()", src)
201
 
202
- return src
 
 
 
 
203
 
204
  # Patch both files where they may appear
205
  targets = list(local_dir.rglob("modeling_deepseekocr.py")) + \
206
  list(local_dir.rglob("deepencoder.py"))
207
 
208
  if not targets:
209
- raise RuntimeError("Could not locate DeepSeek-OCR source files to patch")
210
 
211
- for p in targets:
212
- print(f" 🔍 Found file: {p.name}")
213
- txt = p.read_text(encoding="utf-8")
214
- new = patch_text(txt)
215
-
216
- if new != txt:
217
- p.write_text(new, encoding="utf-8")
218
- print(f" ✅ Patched for CPU+FP32: {p.name}")
219
- else:
220
- print(f" ℹ️ Already CPU-patched: {p.name}")
 
 
221
 
222
  return str(local_dir)
223
 
 
133
  Download DeepSeek-OCR to a local dir, patch for CPU:
134
  - remove hardcoded .cuda()
135
  - force float32 (strip .bfloat16() / .to(torch.bfloat16))
136
+
137
+ Minimal patcher that avoids indentation issues by NOT touching autocast blocks.
138
+ On CPU, torch.autocast is auto-disabled anyway, so we leave it alone.
139
 
140
  Return local path for from_pretrained(...).
141
 
 
155
  print(f" ✅ Downloaded to: {local_dir}")
156
  local_dir = Path(local_dir)
157
 
158
+ def patch_file(p: Path):
159
+ """Minimal patch - only string replacements, no indentation changes"""
160
+ txt0 = p.read_text(encoding="utf-8")
161
+ txt = txt0
162
+
163
+ # A) Remove hardcoded CUDA device moves (CPU-safe)
164
+ txt = txt.replace(".unsqueeze(-1).cuda()", ".unsqueeze(-1)")
165
+ txt = txt.replace("input_ids.unsqueeze(0).cuda()", "input_ids.unsqueeze(0)")
166
+ txt = txt.replace("(images_crop.cuda(), images_ori.cuda())", "(images_crop, images_ori)")
167
+ txt = txt.replace("images_seq_mask = images_seq_mask.unsqueeze(0).cuda()",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  "images_seq_mask = images_seq_mask.unsqueeze(0)")
169
+ txt = txt.replace("input_ids.unsqueeze(0).cuda().shape[1]",
170
  "input_ids.unsqueeze(0).shape[1]")
171
 
172
+ # B) Force FP32 (fix BF16 vs FP32), pattern-safe (no newlines/indentation changes)
173
+ txt = re.sub(r"\.bfloat16\(\)", ".float()", txt)
174
+ txt = re.sub(r"\.to\(\s*torch\.bfloat16\s*\)", ".to(torch.float32)", txt)
175
+ txt = re.sub(r"\.to\(\s*dtype\s*=\s*torch\.bfloat16\s*\)", ".to(dtype=torch.float32)", txt)
176
 
177
+ # Note: We do NOT touch torch.autocast() blocks - on CPU they're auto-disabled
178
+ # and modifying them risks breaking indentation/syntax
 
 
 
 
 
179
 
180
+ if txt != txt0:
181
+ p.write_text(txt, encoding="utf-8")
182
+ print(f" ✅ Patched CPU/FP32: {p.name}")
183
+ else:
184
+ print(f" ℹ️ Already CPU/FP32-safe: {p.name}")
185
 
186
  # Patch both files where they may appear
187
  targets = list(local_dir.rglob("modeling_deepseekocr.py")) + \
188
  list(local_dir.rglob("deepencoder.py"))
189
 
190
  if not targets:
191
+ raise RuntimeError("Could not find DeepSeek-OCR source files to patch")
192
 
193
+ for f in targets:
194
+ print(f" 🔍 Found file: {f.name}")
195
+ patch_file(f)
196
+
197
+ # Optional: compile check to catch syntax errors early
198
+ try:
199
+ import py_compile
200
+ for f in targets:
201
+ py_compile.compile(str(f), doraise=True)
202
+ print(f" Syntax check passed for {len(targets)} file(s)")
203
+ except py_compile.PyCompileError as e:
204
+ raise RuntimeError(f"Syntax check failed after patch: {e}")
205
 
206
  return str(local_dir)
207