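"""LRU cache for model attention outputs, keyed on generation parameters.

Cached values are pickled; torch tensors found in per-layer attention
dicts are moved to CPU NumPy arrays on the way in and restored to torch
tensors on the way out, so entries are device-independent.
"""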
from typing import Dict, Any, Optional
import hashlib
import io
import pickle

import numpy as np
import torch

class AttentionCache:
    def __init__(self, max_size: int = 10):
        self.cache = {}          # key -> pickled payload
        self.access_order = []   # keys ordered least -> most recently used
        self.max_size = max_size
    
    def get_key(self, prompt: str, max_tokens: int, model: str, temperature: float = 0.7) -> str:
        """Generate cache key from parameters"""
        data = f"{prompt}_{max_tokens}_{model}_{temperature}"
        return hashlib.md5(data.encode()).hexdigest()
    
    def get(self, key: str) -> Optional[Dict[str, Any]]:
        """Retrieve cached data"""
        if key in self.cache:
            # Move to end (LRU)
            self.access_order.remove(key)
            self.access_order.append(key)
            return self._deserialize(self.cache[key])
        return None
    
    def set(self, key: str, data: Dict[str, Any]):
        """Store data in cache"""
        if key in self.cache:
            # Re-setting an existing key: refresh its recency, don't evict
            self.access_order.remove(key)
        elif len(self.cache) >= self.max_size:
            # Remove least recently used
            oldest = self.access_order.pop(0)
            del self.cache[oldest]

        self.cache[key] = self._serialize(data)
        self.access_order.append(key)
    
    def _serialize(self, data: Dict[str, Any]) -> bytes:
        """Serialize data for caching, handling torch tensors"""
        serialized = {}
        for key, value in data.items():
            # Attention matrices arrive as a list of per-layer dicts whose
            # values are tensors; detach them to CPU NumPy arrays so pickled
            # entries are device-independent and don't hold GPU memory.
            if (isinstance(value, list) and value
                    and isinstance(value[0], dict)
                    and any(isinstance(v, torch.Tensor) for v in value[0].values())):
                serialized[key] = [
                    {k: v.detach().cpu().numpy() if isinstance(v, torch.Tensor) else v
                     for k, v in item.items()}
                    for item in value
                ]
            else:
                serialized[key] = value

        buffer = io.BytesIO()
        pickle.dump(serialized, buffer)
        return buffer.getvalue()
    
    def _deserialize(self, data: bytes) -> Dict[str, Any]:
        """Deserialize data from cache, restoring torch tensors"""
        buffer = io.BytesIO(data)
        deserialized = pickle.load(buffer)

        # Convert NumPy arrays (formerly tensors) back to torch tensors
        for value in deserialized.values():
            if isinstance(value, list) and value and isinstance(value[0], dict):
                for item in value:
                    for k, v in item.items():
                        if isinstance(v, np.ndarray):
                            item[k] = torch.from_numpy(v)

        return deserialized
    
    def clear(self):
        """Clear the entire cache"""
        self.cache.clear()
        self.access_order.clear()
    
    def size(self) -> int:
        """Get current cache size"""
        return len(self.cache)
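
if __name__ == "__main__":
    # Minimal usage sketch (nothing here is assumed beyond the class above):
    # cache a fake generation result containing a per-layer attention tensor,
    # then read it back and confirm the tensor round-trips through NumPy.
    # The model name and tensor shape are illustrative placeholders.
    cache = AttentionCache(max_size=2)

    key = cache.get_key("Hello, world", max_tokens=16, model="demo-model")
    result = {
        "text": "Hello back!",
        "attentions": [{"layer_0": torch.rand(1, 4, 8, 8)}],
    }
    cache.set(key, result)

    cached = cache.get(key)
    assert cached is not None
    assert isinstance(cached["attentions"][0]["layer_0"], torch.Tensor)
    print(f"cache size: {cache.size()}")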