File size: 6,583 Bytes
6cc5a7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
"""Local FastAPI service: STEP-file-in, classification-out. One process, models loaded once.

Designed to be launched by a C# host app as a child process. The service binds to
127.0.0.1 by default (see --host) and is not meant to expose a network-facing port.

Startup protocol:
  - Bind a free port (default: OS-assigned via --port 0).
  - Once models are loaded and the server is accepting requests, write
    `READY port=<PORT>` to stdout (single line). The host reads that line to learn
    where to connect.
  - On `POST /shutdown` (or SIGTERM), drain in-flight requests and exit.

Endpoints:
  GET  /health                 -> service / model state
  POST /classify {step_path}   -> single STEP classification
  POST /classify_batch {step_paths} -> multiple in one call (sequential, single-process)
  POST /shutdown               -> graceful exit
"""
from __future__ import annotations
import argparse
import os
import shutil
import sys
import tempfile
import threading
import time
from pathlib import Path
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn

from heg_brep import (
    DEFAULT_PASS1_MODEL, DEFAULT_ELBOW_MODEL, DEFAULT_TEE_MODEL,
)
from heg_brep.inference import LoadedModel, TwoPassClassifier
from heg_brep.extraction import extract_step_to_npz

# Eagerly import the BRepExtractor pipeline at startup. Without this, the first
# /classify call pays ~20-30s of OCC + occwl + igl cold-import cost — terrible
# UX for an interactive viewer.
import pipeline.extract_brep_extractor_data_from_step  # noqa: F401


class State:
    """Process-wide service state, kept as class attributes and used as a namespace.

    ``main`` records ``started_at``; ``_load_models`` fills in the classifier and the
    settings it was built with; the HTTP handlers only read these values.
    """

    # Wall-clock time (time.time()) at which startup began; used for /health uptime.
    started_at: float = 0.0
    # Device string the models were loaded on, reported by /health.
    device: str = "cpu"
    # The loaded two-pass classifier; None until model loading has finished.
    classifier: Optional[TwoPassClassifier] = None
    # Pass-2 settings copied from the command line when models are loaded.
    pass2_min_conf: float = 0.85
    pass2_tau: float = 0.0


# The single ASGI application; the endpoint functions below register on it via decorators.
app = FastAPI(
    title="HEG BRep component identification",
)


class ClassifyRequest(BaseModel):
    # JSON body of POST /classify.
    # Path to the STEP file to classify; the handler applies Path.expanduser() to it.
    step_path: str
    # Optional override: persist the NPZ to disk for inspection.
    # When omitted (None), the NPZ is written to a temporary directory that the
    # handler removes before responding.
    npz_keep_dir: Optional[str] = None


class ClassifyBatchRequest(BaseModel):
    # JSON body of POST /classify_batch.
    # STEP file paths, classified one after another in the given order.
    step_paths: List[str]
    # Optional directory in which every extracted NPZ is kept; applied to all paths.
    npz_keep_dir: Optional[str] = None


@app.get("/health")
def health():
    """Report liveness, whether the models are loaded, the device, and uptime in seconds."""
    loaded = State.classifier is not None
    elapsed = time.time() - State.started_at
    return dict(
        status="ok",
        models_loaded=loaded,
        device=State.device,
        uptime_sec=round(elapsed, 2),
    )


def _classify_one(step_path: str, npz_keep_dir: Optional[str]) -> dict:
    """Extract one STEP file to NPZ and run the two-pass classifier on it.

    Per-file failures are reported in the returned dict instead of being raised, so
    a batch continues past a bad file. ``status`` is one of:

    - ``"ok"``: the classifier's result, updated with ``step_path`` and ``npz_path``;
    - ``"error"``: the STEP path does not exist;
    - ``"extraction_failed"`` / ``"inference_failed"``: ``error`` holds the
      exception text truncated to 500 characters.

    Args:
        step_path: Path to the STEP file; ``~`` is expanded.
        npz_keep_dir: Directory in which to keep the extracted NPZ. When None or
            empty, the NPZ goes to a fresh temporary directory that is deleted
            before this function returns, so the reported ``npz_path`` will no
            longer exist on disk.

    Raises:
        HTTPException: 503 if the models have not been loaded yet.
    """
    if State.classifier is None:
        raise HTTPException(status_code=503, detail="models not loaded yet")
    sp = Path(step_path).expanduser()
    if not sp.exists():
        return {"step_path": str(sp), "status": "error", "error": "step_path not found"}

    # A single flag decides both where the NPZ goes and whether it is cleaned up.
    # Testing `is None` for one and truthiness for the other leaked a temp dir for "".
    keep_npz = bool(npz_keep_dir)
    if keep_npz:
        out_dir = Path(npz_keep_dir).expanduser().resolve()
    else:
        out_dir = Path(tempfile.mkdtemp(prefix="heg_brep_npz_"))
    try:
        try:
            npz = extract_step_to_npz(sp, out_dir)
        except Exception as exc:
            # Service boundary: report the failure for this file instead of a 500.
            return {"step_path": str(sp), "status": "extraction_failed", "error": str(exc)[:500]}
        try:
            result = State.classifier.classify_npz(npz)
        except Exception as exc:
            return {"step_path": str(sp), "status": "inference_failed", "error": str(exc)[:500],
                    "npz_path": str(npz)}
        result.update({"step_path": str(sp), "status": "ok", "npz_path": str(npz)})
        return result
    finally:
        if not keep_npz:
            shutil.rmtree(out_dir, ignore_errors=True)


@app.post("/classify")
def classify(req: ClassifyRequest):
    """Classify a single STEP file and return its per-file result dict."""
    path, keep_dir = req.step_path, req.npz_keep_dir
    return _classify_one(path, keep_dir)


@app.post("/classify_batch")
def classify_batch(req: ClassifyBatchRequest):
    """Classify each path in order, one at a time, sharing the request's npz_keep_dir."""
    keep_dir = req.npz_keep_dir
    results = []
    for path in req.step_paths:
        results.append(_classify_one(path, keep_dir))
    return {"results": results}


@app.post("/shutdown")
def shutdown():
    """Ask the service to exit and acknowledge immediately.

    If a running uvicorn server has been published as ``State.server``, set its
    ``should_exit`` flag. uvicorn's main loop polls that flag and runs its graceful
    shutdown: it stops accepting connections and lets in-flight requests, including
    this one, finish. ``server.run()`` then returns normally. Without that handle,
    fall back to a hard ``os._exit(0)`` after a short delay so this response can
    still be sent first.
    """
    server = getattr(State, "server", None)
    if server is not None:
        server.should_exit = True
    else:
        # os._exit skips all cleanup and kills in-flight work; last resort only.
        threading.Thread(target=lambda: (time.sleep(0.2), os._exit(0)), daemon=True).start()
    return {"status": "shutting_down"}


def _load_models(args) -> None:
    """Build the TwoPassClassifier from the CLI arguments and store it on ``State``.

    The pass-1 model is always loaded; its path is not checked here (presumably
    ``LoadedModel`` raises if it is missing — confirm). The elbow and tee models are
    optional: if a path does not exist, ``None`` is passed to the classifier and a
    note is printed to stderr so the degraded mode is visible.

    Args:
        args: Namespace from ``parse_args`` (model paths, device, pass-2 settings).
    """
    def load_optional(label: str, path_str: str) -> Optional[LoadedModel]:
        # Skip a missing optional model, but say so instead of degrading silently.
        path = Path(path_str)
        if not path.exists():
            print(f"[heg_brep] {label} model not found, skipping: {path}", file=sys.stderr)
            return None
        return LoadedModel(path, device=args.device)

    pass1 = LoadedModel(Path(args.pass1_model), device=args.device)
    elbow = load_optional("elbow", args.elbow_model)
    tee = load_optional("tee", args.tee_model)
    State.classifier = TwoPassClassifier(
        pass1=pass1, elbow=elbow, tee=tee,
        pass2_min_conf=args.pass2_min_conf, pass2_tau=args.pass2_tau,
    )
    State.device = args.device
    State.pass2_min_conf = args.pass2_min_conf
    State.pass2_tau = args.pass2_tau


def parse_args() -> argparse.Namespace:
    """Parse the service's command-line options.

    Model paths default to the package's DEFAULT_*_MODEL constants; ``--port 0``
    lets the OS pick a free port.
    """
    parser = argparse.ArgumentParser(description="HEG BRep classification service")
    add = parser.add_argument
    add("--host", default="127.0.0.1")
    add(
        "--port",
        type=int,
        default=0,
        help="Port to bind. 0 = let the OS pick. The chosen port is "
             "printed to stdout as a single line `READY port=<N>`.",
    )
    model_flags = (
        ("--pass1_model", DEFAULT_PASS1_MODEL),
        ("--elbow_model", DEFAULT_ELBOW_MODEL),
        ("--tee_model", DEFAULT_TEE_MODEL),
    )
    for flag, default_path in model_flags:
        add(flag, default=str(default_path))
    add("--device", default="cpu", choices=["cpu", "cuda"])
    add("--pass2_min_conf", type=float, default=0.85)
    add("--pass2_tau", type=float, default=0.0)
    return parser.parse_args()


def main() -> int:
    """Load the models, bind and announce the listening port, then serve until stopped.

    Ordering matters to the host process:
      1. Models are loaded first, so ``READY`` is only printed once they are usable.
      2. The socket is bound and put into the listening state *before* ``READY`` is
         printed. A client that connects as soon as it reads that line is queued in
         the kernel backlog instead of being refused.
      3. The same socket object is handed to uvicorn. Closing it and letting uvicorn
         rebind would leave a window in which another process could take an
         OS-assigned port.

    Returns:
        Process exit code: 0 once the server has stopped.
    """
    import socket

    args = parse_args()
    State.started_at = time.time()
    print("[heg_brep] loading models from:", file=sys.stderr)
    print(f"  pass1: {args.pass1_model}", file=sys.stderr)
    print(f"  elbow: {args.elbow_model}", file=sys.stderr)
    print(f"  tee  : {args.tee_model}",   file=sys.stderr)
    print(f"  device: {args.device}",     file=sys.stderr)
    t0 = time.time()
    _load_models(args)
    print(f"[heg_brep] models loaded in {time.time() - t0:.1f}s", file=sys.stderr)

    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # Same default as asyncio: SO_REUSEADDR only on POSIX. On Windows it lets a
        # second socket bind the same port, which would defeat binding up front.
        if os.name != "nt":
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind((args.host, int(args.port)))
        # Listen now so early connections wait in the backlog until uvicorn's loop
        # starts accepting (uvicorn calls listen() again with its own backlog).
        sock.listen()
        bound_port = sock.getsockname()[1]

        config = uvicorn.Config(app=app, host=args.host, port=bound_port, log_level="warning")
        server = uvicorn.Server(config)
        # Expose the server instance on State (e.g. so a request handler can set
        # `server.should_exit` to request a graceful stop).
        State.server = server

        print(f"READY port={bound_port}", flush=True)
        # Serve on the already-bound, already-listening socket; no rebind race.
        server.run(sockets=[sock])
    finally:
        sock.close()
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())