kimhyunwoo committed on
Commit
9c61a12
·
verified ·
1 Parent(s): 4581cbe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +229 -51
app.py CHANGED
@@ -9,12 +9,16 @@ from torchvision import datasets, transforms
9
  import numpy as np
10
  import threading
11
 
12
- app = FastAPI(title="3D CNN Visualizer + MNIST", version="0.2.0")
13
 
14
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
15
  MODEL = None
 
 
16
  TRAINING_DONE = False
17
  TRAINING_ERROR = None
 
18
 
19
 
20
  class SimpleCNN(nn.Module):
@@ -29,8 +33,8 @@ class SimpleCNN(nn.Module):
29
  """
30
  def __init__(self):
31
  super().__init__()
32
- self.conv1 = nn.Conv2d(1, 4, kernel_size=5) # 28 -> 24
33
- self.conv2 = nn.Conv2d(4, 8, kernel_size=5) # 12 -> 8
34
  self.pool = nn.MaxPool2d(2, 2)
35
  self.fc = nn.Linear(8 * 4 * 4, 10)
36
 
@@ -59,14 +63,15 @@ class SimpleCNN(nn.Module):
59
  return x
60
 
61
 
62
- def train_model():
63
- global MODEL, TRAINING_DONE, TRAINING_ERROR
 
64
  try:
65
  transform = transforms.ToTensor()
66
  train_dataset = datasets.MNIST(
67
  root="./data", train=True, download=True, transform=transform
68
  )
69
- subset_size = min(10000, len(train_dataset))
70
  train_subset = torch.utils.data.Subset(train_dataset, list(range(subset_size)))
71
  loader = DataLoader(train_subset, batch_size=128, shuffle=True)
72
 
@@ -75,7 +80,7 @@ def train_model():
75
  criterion = nn.CrossEntropyLoss()
76
 
77
  model.train()
78
- epochs = 1
79
  for _ in range(epochs):
80
  for images, labels in loader:
81
  images, labels = images.to(DEVICE), labels.to(DEVICE)
@@ -86,21 +91,29 @@ def train_model():
86
  optimizer.step()
87
 
88
  model.eval()
89
- MODEL = model
 
 
 
90
  TRAINING_DONE = True
91
  except Exception as e:
92
  TRAINING_ERROR = str(e)
93
  TRAINING_DONE = False
94
 
95
 
96
- # 학습을 백그라운드에서 시작
97
- threading.Thread(target=train_model, daemon=True).start()
98
 
99
 
100
  class PredictRequest(BaseModel):
101
  pixels: list[float] # 28*28 = 784
102
 
103
 
 
 
 
 
 
104
  @app.get("/", response_class=HTMLResponse)
105
  async def index():
106
  return HTML_PAGE
@@ -115,6 +128,26 @@ async def status():
115
  }
116
 
117
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
  @app.post("/predict")
119
  async def predict(req: PredictRequest):
120
  if TRAINING_ERROR:
@@ -136,28 +169,74 @@ async def predict(req: PredictRequest):
136
 
137
  arr = np.array(req.pixels, dtype=np.float32).reshape(1, 1, 28, 28)
138
  x = torch.from_numpy(arr).to(DEVICE)
139
- with torch.no_grad():
140
- logits, acts = MODEL(x, return_activations=True)
141
- probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
142
 
143
- conv1 = acts["conv1"].cpu().numpy()[0].tolist() # [4,24,24]
144
- pool1 = acts["pool1"].cpu().numpy()[0].tolist() # [4,12,12]
145
- conv2 = acts["conv2"].cpu().numpy()[0].tolist() # [8,8,8]
146
- pool2 = acts["pool2"].cpu().numpy()[0].tolist() # [8,4,4]
147
- flat = acts["flat"].cpu().numpy()[0].tolist() # [128]
148
 
149
  predicted_class = int(probs.argmax())
150
 
151
  return {
152
  "predicted_class": predicted_class,
153
  "probabilities": probs.tolist(),
154
- "activations": {
155
- "conv1": conv1,
156
- "pool1": pool1,
157
- "conv2": conv2,
158
- "pool2": pool2,
159
- "flat": flat,
160
- },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  }
162
 
163
 
@@ -544,6 +623,7 @@ HTML_PAGE = r"""
544
  const [statusText, setStatusText] = useState("Checking model status...");
545
  const [lastPrediction, setLastPrediction] = useState(null);
546
  const [padKey, setPadKey] = useState(0);
 
547
 
548
  useEffect(() => {
549
  audio.init();
@@ -577,13 +657,17 @@ HTML_PAGE = r"""
577
 
578
  const delay = ms => new Promise(r => setTimeout(r, ms));
579
 
 
 
 
 
 
 
580
  const run = async () => {
581
- if(processing) return;
582
  setProcessing(true);
583
 
584
- // 1) 입력 플래튼
585
- const flat = [];
586
- for(let y=0; y<28; y++) for(let x=0; x<28; x++) flat.push(inputData[y][x]);
587
 
588
  let probs = null;
589
  let predClass = null;
@@ -601,7 +685,8 @@ HTML_PAGE = r"""
601
  acts = json.activations;
602
  setLastPrediction({
603
  cls: predClass,
604
- conf: probs[predClass]
 
605
  });
606
  } else {
607
  alert("Error: " + (json.error || "Unknown"));
@@ -615,15 +700,13 @@ HTML_PAGE = r"""
615
  return;
616
  }
617
 
618
- // 2) PyTorch activations 시각화용 포맷 변환
619
- // conv1 / pool1 / conv2 / pool2 는 그대로 사용 (depth x h x w)
620
  const conv1 = acts.conv1 || [];
621
  const pool1 = acts.pool1 || [];
622
  const conv2 = acts.conv2 || [];
623
  const pool2 = acts.pool2 || [];
624
- const flatVec = acts.flat || []; // length 128
625
 
626
- // flat: 1 x 8 x 16으로 reshape
627
  const flatGrid = [];
628
  for(let y=0; y<8; y++) {
629
  const row = [];
@@ -633,38 +716,105 @@ HTML_PAGE = r"""
633
  }
634
  flatGrid.push(row);
635
  }
636
- const flatData = [flatGrid]; // depth=1
637
-
638
- // fc: [probability] 형태로 감싸서 시각화
639
  const fcData = [probs.map(p => [p])];
640
 
641
- // 3) 단계별로 activations 상태 업데이트 (애니메이션)
642
  setActivations(prev => ({...prev, input: [inputData]}));
643
- setStep(0); audio.playStep(0); await delay(400);
644
 
645
  setActivations(prev => ({...prev, conv1}));
646
- setStep(1); audio.playStep(1); await delay(400);
647
 
648
  setActivations(prev => ({...prev, pool1}));
649
- setStep(2); audio.playStep(2); await delay(400);
650
 
651
  setActivations(prev => ({...prev, conv2}));
652
- setStep(3); audio.playStep(3); await delay(400);
653
 
654
  setActivations(prev => ({...prev, pool2}));
655
- setStep(4); audio.playStep(4); await delay(400);
656
 
657
  setActivations(prev => ({...prev, flat: flatData}));
658
- setStep(5); audio.playStep(5); await delay(400);
659
 
660
  setActivations(prev => ({...prev, fc: fcData}));
661
  setStep(6); audio.playStep(6);
662
 
663
- await delay(1500);
664
  setProcessing(false);
665
  setStep(-1);
666
  };
667
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
668
  return (
669
  <div className="w-full h-screen relative bg-black font-mono">
670
  <Canvas shadows camera={{ position: [25, 10, 5], fov: 45 }}>
@@ -692,7 +842,7 @@ HTML_PAGE = r"""
692
  <Cpu className="text-neon-green animate-pulse" /> DEEP <span className="text-neon-green">CNN</span>
693
  </h1>
694
  <div className="text-xs text-green-400 mt-2 flex items-center gap-2">
695
- <Activity size={12} /> {processing ? "PROCESSING TENSORS..." : "ONLINE"}
696
  </div>
697
  <div className="text-[10px] text-green-500 mt-1">{statusText}</div>
698
  {lastPrediction && (
@@ -715,15 +865,43 @@ HTML_PAGE = r"""
715
  <div className="text-xs font-bold tracking-widest flex items-center gap-2">
716
  <Scan size={14} /> INPUT SENSOR
717
  </div>
718
- <button onClick={reset} disabled={processing} className="hover:text-white transition-colors">
719
  <RotateCcw size={16} />
720
  </button>
721
  </div>
722
- <DrawingPad key={padKey} data={inputData} onChange={setInputData} disabled={processing} />
723
- <button onClick={run} disabled={processing} className="w-full mt-4 py-3 rounded btn-holo flex justify-center items-center gap-2 font-bold transition-all">
724
- {processing ? <Activity className="animate-spin" size={18} /> : <Play size={18} fill="currentColor" />}
725
- {processing ? 'CALCULATING...' : 'RUN INFERENCE'}
726
  </button>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
727
  </div>
728
 
729
  <div className="hud-panel p-5 hidden md:block rounded-t-xl min-w-[260px] border-b-0">
 
9
  import numpy as np
10
  import threading
11
 
12
+ app = FastAPI(title="3D CNN Visualizer + MNIST + Online Learning", version="0.3.0")
13
 
14
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
+
16
  MODEL = None
17
+ OPTIMIZER = None
18
+ CRITERION = None
19
  TRAINING_DONE = False
20
  TRAINING_ERROR = None
21
+ MODEL_LOCK = threading.Lock()
22
 
23
 
24
  class SimpleCNN(nn.Module):
 
33
  """
34
  def __init__(self):
35
  super().__init__()
36
+ self.conv1 = nn.Conv2d(1, 4, kernel_size=5)
37
+ self.conv2 = nn.Conv2d(4, 8, kernel_size=5)
38
  self.pool = nn.MaxPool2d(2, 2)
39
  self.fc = nn.Linear(8 * 4 * 4, 10)
40
 
 
63
  return x
64
 
65
 
66
+ def train_model_initial():
67
+ """Offline initial training on MNIST subset."""
68
+ global MODEL, OPTIMIZER, CRITERION, TRAINING_DONE, TRAINING_ERROR
69
  try:
70
  transform = transforms.ToTensor()
71
  train_dataset = datasets.MNIST(
72
  root="./data", train=True, download=True, transform=transform
73
  )
74
+ subset_size = min(20000, len(train_dataset))
75
  train_subset = torch.utils.data.Subset(train_dataset, list(range(subset_size)))
76
  loader = DataLoader(train_subset, batch_size=128, shuffle=True)
77
 
 
80
  criterion = nn.CrossEntropyLoss()
81
 
82
  model.train()
83
+ epochs = 2
84
  for _ in range(epochs):
85
  for images, labels in loader:
86
  images, labels = images.to(DEVICE), labels.to(DEVICE)
 
91
  optimizer.step()
92
 
93
  model.eval()
94
+ with MODEL_LOCK:
95
+ MODEL = model
96
+ OPTIMIZER = optimizer
97
+ CRITERION = criterion
98
  TRAINING_DONE = True
99
  except Exception as e:
100
  TRAINING_ERROR = str(e)
101
  TRAINING_DONE = False
102
 
103
 
104
+ # start initial training in background
105
+ threading.Thread(target=train_model_initial, daemon=True).start()
106
 
107
 
108
  class PredictRequest(BaseModel):
109
  pixels: list[float] # 28*28 = 784
110
 
111
 
112
class FeedbackRequest(BaseModel):
    """Request body for ``POST /feedback``: a drawn digit and its true label."""

    # Flattened image; the endpoint expects 28 * 28 = 784 values.
    pixels: list[float]
    # Correct digit for the drawing; the endpoint accepts 0 through 9.
    label: int
115
+
116
+
117
  @app.get("/", response_class=HTMLResponse)
118
  async def index():
119
  return HTML_PAGE
 
128
  }
129
 
130
 
131
+ def _forward_with_activations(x_tensor: torch.Tensor):
132
+ """Run model forward and return probs + activations as python lists."""
133
+ logits, acts = MODEL(x_tensor, return_activations=True)
134
+ probs = torch.softmax(logits, dim=1).cpu().numpy()[0]
135
+
136
+ conv1 = acts["conv1"].detach().cpu().numpy()[0].tolist() # [4,24,24]
137
+ pool1 = acts["pool1"].detach().cpu().numpy()[0].tolist() # [4,12,12]
138
+ conv2 = acts["conv2"].detach().cpu().numpy()[0].tolist() # [8,8,8]
139
+ pool2 = acts["pool2"].detach().cpu().numpy()[0].tolist() # [8,4,4]
140
+ flat = acts["flat"].detach().cpu().numpy()[0].tolist() # [128]
141
+
142
+ return probs, {
143
+ "conv1": conv1,
144
+ "pool1": pool1,
145
+ "conv2": conv2,
146
+ "pool2": pool2,
147
+ "flat": flat,
148
+ }
149
+
150
+
151
  @app.post("/predict")
152
  async def predict(req: PredictRequest):
153
  if TRAINING_ERROR:
 
169
 
170
  arr = np.array(req.pixels, dtype=np.float32).reshape(1, 1, 28, 28)
171
  x = torch.from_numpy(arr).to(DEVICE)
 
 
 
172
 
173
+ with MODEL_LOCK:
174
+ MODEL.eval()
175
+ with torch.no_grad():
176
+ probs, acts = _forward_with_activations(x)
 
177
 
178
  predicted_class = int(probs.argmax())
179
 
180
  return {
181
  "predicted_class": predicted_class,
182
  "probabilities": probs.tolist(),
183
+ "activations": acts,
184
+ }
185
+
186
+
187
@app.post("/feedback")
async def feedback(req: FeedbackRequest):
    """Online learning: fine-tune the shared model on one user-labelled digit.

    Runs up to three gradient steps on the submitted sample while holding
    ``MODEL_LOCK``, then re-runs inference on the same sample.

    Args:
        req: The drawn digit (784 pixel values) and its correct label (0-9).

    Returns:
        On success, a dict with ``status``, the last ``loss``, the new
        ``predicted_class``, ``probabilities`` and ``activations``.
        Otherwise a ``JSONResponse``:
        500 if initial training failed, 503 if the model is not ready,
        400 for a malformed payload or an update that produced a
        non-finite loss.
    """
    if TRAINING_ERROR:
        return JSONResponse(
            status_code=500,
            content={"error": "Training failed", "detail": TRAINING_ERROR},
        )
    if not TRAINING_DONE or MODEL is None or OPTIMIZER is None or CRITERION is None:
        return JSONResponse(
            status_code=503,
            content={"error": "Model not ready yet."},
        )

    if len(req.pixels) != 28 * 28:
        return JSONResponse(
            status_code=400,
            content={"error": "pixels must have length 784 (28x28)"},
        )
    if not (0 <= req.label <= 9):
        return JSONResponse(
            status_code=400,
            content={"error": "label must be between 0 and 9"},
        )

    arr = np.array(req.pixels, dtype=np.float32).reshape(1, 1, 28, 28)
    # A single NaN/Inf pixel (or a value that overflows float32) would flow
    # through the optimizer step and leave the shared weights non-finite for
    # every later request, so reject it before touching the model.
    if not np.isfinite(arr).all():
        return JSONResponse(
            status_code=400,
            content={"error": "pixels must be finite numbers"},
        )
    x = torch.from_numpy(arr).to(DEVICE)
    y = torch.tensor([req.label], dtype=torch.long, device=DEVICE)

    with MODEL_LOCK:
        MODEL.train()
        loss_val = None
        diverged = False
        try:
            # do a few small gradient steps on this sample
            for _ in range(3):
                OPTIMIZER.zero_grad()
                logits = MODEL(x)
                loss = CRITERION(logits, y)
                if not torch.isfinite(loss):
                    # Never back-propagate a non-finite loss into the weights.
                    diverged = True
                    break
                loss.backward()
                OPTIMIZER.step()
                loss_val = float(loss.item())
        finally:
            # Always hand the model back in eval mode, even if a step raised.
            MODEL.eval()

        if diverged:
            return JSONResponse(
                status_code=400,
                content={"error": "update aborted: loss became non-finite"},
            )

        # re-run forward to get updated prediction & activations
        with torch.no_grad():
            probs, acts = _forward_with_activations(x)

    predicted_class = int(probs.argmax())

    return {
        "status": "ok",
        "loss": loss_val,
        "predicted_class": predicted_class,
        "probabilities": probs.tolist(),
        "activations": acts,
    }
241
 
242
 
 
623
  const [statusText, setStatusText] = useState("Checking model status...");
624
  const [lastPrediction, setLastPrediction] = useState(null);
625
  const [padKey, setPadKey] = useState(0);
626
+ const [feedbackBusy, setFeedbackBusy] = useState(false);
627
 
628
  useEffect(() => {
629
  audio.init();
 
657
 
658
  const delay = ms => new Promise(r => setTimeout(r, ms));
659
 
660
// Flatten a 28x28 grid into a 784-element array, row by row (y outer, x inner).
const flattenInput = (grid) => {
  const cells = [];
  for (let idx = 0; idx < 28 * 28; idx++) {
    const row = Math.floor(idx / 28);
    const col = idx % 28;
    cells.push(grid[row][col]);
  }
  return cells;
};
665
+
666
  const run = async () => {
667
+ if(processing || feedbackBusy) return;
668
  setProcessing(true);
669
 
670
+ const flat = flattenInput(inputData);
 
 
671
 
672
  let probs = null;
673
  let predClass = null;
 
685
  acts = json.activations;
686
  setLastPrediction({
687
  cls: predClass,
688
+ conf: probs[predClass],
689
+ probs,
690
  });
691
  } else {
692
  alert("Error: " + (json.error || "Unknown"));
 
700
  return;
701
  }
702
 
703
+ // convert activations for visualization
 
704
  const conv1 = acts.conv1 || [];
705
  const pool1 = acts.pool1 || [];
706
  const conv2 = acts.conv2 || [];
707
  const pool2 = acts.pool2 || [];
708
+ const flatVec = acts.flat || [];
709
 
 
710
  const flatGrid = [];
711
  for(let y=0; y<8; y++) {
712
  const row = [];
 
716
  }
717
  flatGrid.push(row);
718
  }
719
+ const flatData = [flatGrid];
 
 
720
  const fcData = [probs.map(p => [p])];
721
 
 
722
  setActivations(prev => ({...prev, input: [inputData]}));
723
+ setStep(0); audio.playStep(0); await delay(250);
724
 
725
  setActivations(prev => ({...prev, conv1}));
726
+ setStep(1); audio.playStep(1); await delay(250);
727
 
728
  setActivations(prev => ({...prev, pool1}));
729
+ setStep(2); audio.playStep(2); await delay(250);
730
 
731
  setActivations(prev => ({...prev, conv2}));
732
+ setStep(3); audio.playStep(3); await delay(250);
733
 
734
  setActivations(prev => ({...prev, pool2}));
735
+ setStep(4); audio.playStep(4); await delay(250);
736
 
737
  setActivations(prev => ({...prev, flat: flatData}));
738
+ setStep(5); audio.playStep(5); await delay(250);
739
 
740
  setActivations(prev => ({...prev, fc: fcData}));
741
  setStep(6); audio.playStep(6);
742
 
743
+ await delay(1000);
744
  setProcessing(false);
745
  setStep(-1);
746
  };
747
 
748
// Send the current drawing with its user-supplied correct label to
// POST /feedback, then refresh the prediction and every visualised layer
// from the response. Does nothing until a prediction exists, or while an
// inference run or another update is in flight.
const sendFeedback = async (correctLabel) => {
  if (!lastPrediction || processing || feedbackBusy) return;
  setFeedbackBusy(true);
  setStatusText(`Online update with label ${correctLabel} ...`);

  const flat = flattenInput(inputData);

  try {
    const res = await fetch("/feedback", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify({ pixels: flat, label: correctLabel }),
    });
    // Error responses are not guaranteed to carry a JSON body; fall back to
    // an empty object so the status check below still reports the failure.
    const json = await res.json().catch(() => ({}));
    if (!res.ok) {
      alert("Feedback error: " + (json.error || "Unknown"));
      setStatusText("Online update failed.");
      return; // the `finally` block clears the busy flag
    }

    const probs = json.probabilities;
    const acts = json.activations;
    const predClass = json.predicted_class;

    setLastPrediction({
      cls: predClass,
      conf: probs[predClass],
      probs,
    });

    // update visualization with new activations
    const conv1 = acts.conv1 || [];
    const pool1 = acts.pool1 || [];
    const conv2 = acts.conv2 || [];
    const pool2 = acts.pool2 || [];
    const flatVec = acts.flat || [];

    // Lay the flat vector out as a single 8 x 16 slice, padding with zeros.
    const flatGrid = Array.from({ length: 8 }, (_, y) =>
      Array.from({ length: 16 }, (_, x) => flatVec[y * 16 + x] || 0)
    );
    const flatData = [flatGrid];
    const fcData = [probs.map(p => [p])];

    setActivations(prev => ({
      ...prev,
      input: [inputData],
      conv1,
      pool1,
      conv2,
      pool2,
      flat: flatData,
      fc: fcData,
    }));

    // A missing loss must not turn a successful update into a reported failure.
    const lossText = typeof json.loss === "number" ? json.loss.toFixed(4) : "n/a";
    setStatusText(`Online updated with label ${correctLabel} (loss ~${lossText})`);
  } catch (e) {
    console.error(e);
    alert("Feedback request failed.");
    setStatusText("Online update failed.");
  } finally {
    setFeedbackBusy(false);
  }
};
817
+
818
  return (
819
  <div className="w-full h-screen relative bg-black font-mono">
820
  <Canvas shadows camera={{ position: [25, 10, 5], fov: 45 }}>
 
842
  <Cpu className="text-neon-green animate-pulse" /> DEEP <span className="text-neon-green">CNN</span>
843
  </h1>
844
  <div className="text-xs text-green-400 mt-2 flex items-center gap-2">
845
+ <Activity size={12} /> {(processing || feedbackBusy) ? "PROCESSING / UPDATING..." : "ONLINE"}
846
  </div>
847
  <div className="text-[10px] text-green-500 mt-1">{statusText}</div>
848
  {lastPrediction && (
 
865
  <div className="text-xs font-bold tracking-widest flex items-center gap-2">
866
  <Scan size={14} /> INPUT SENSOR
867
  </div>
868
+ <button onClick={reset} disabled={processing || feedbackBusy} className="hover:text-white transition-colors">
869
  <RotateCcw size={16} />
870
  </button>
871
  </div>
872
+ <DrawingPad key={padKey} data={inputData} onChange={setInputData} disabled={processing || feedbackBusy} />
873
+ <button onClick={run} disabled={processing || feedbackBusy} className="w-full mt-4 py-3 rounded btn-holo flex justify-center items-center gap-2 font-bold transition-all">
874
+ {(processing || feedbackBusy) ? <Activity className="animate-spin" size={18} /> : <Play size={18} fill="currentColor" />}
875
+ {(processing || feedbackBusy) ? 'RUNNING / UPDATING...' : 'RUN INFERENCE'}
876
  </button>
877
+
878
+ {lastPrediction && (
879
+ <div className="mt-4">
880
+ <div className="text-[11px] text-green-400 mb-2">
881
+ Select the <span className="font-bold text-neon-green">correct digit</span> for online learning:
882
+ </div>
883
+ <div className="grid grid-cols-5 gap-2">
884
+ {Array.from({length:10}).map((_, d) => {
885
+ const isPred = lastPrediction.cls === d;
886
+ return (
887
+ <button
888
+ key={d}
889
+ onClick={() => sendFeedback(d)}
890
+ disabled={feedbackBusy || processing}
891
+ className={
892
+ "border rounded px-2 py-1 text-xs " +
893
+ (isPred
894
+ ? "border-neon-green text-neon-green bg-black/40"
895
+ : "border-green-800 text-green-400 hover:border-neon-green hover:text-neon-green")
896
+ }
897
+ >
898
+ {d}
899
+ </button>
900
+ );
901
+ })}
902
+ </div>
903
+ </div>
904
+ )}
905
  </div>
906
 
907
  <div className="hud-panel p-5 hidden md:block rounded-t-xl min-w-[260px] border-b-0">