| """
|
| interactive_chat.py - Fully Customizable MAP-NEO Mini Interactive Chat Interface
|
| Features: Real-time parameter tuning, conversation memory, context management, multiple responses
|
| """
|
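
# Example invocation (a sketch; the checkpoint path below is the default used
# by main() and can be overridden by the first CLI argument):
#   python interactive_chat.py checkpoints/extended_context_model.pt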
|
|
| import torch
|
| from transformers import AutoTokenizer
|
| from model_neo import NeoMini, NeoMiniConfig
|
| import os
|
| import json
|
| import time
|
| from pathlib import Path
|
| from datetime import datetime
|
| import gc
|
|
|
class InteractiveChat:
    def __init__(self, checkpoint_path="checkpoints/extended_context_model.pt"):
        self.model = None
        self.tokenizer = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.conversation_history = []
        self.max_context_length = 16384

        # Default generation parameters; adjustable at runtime via /params.
        self.params = {
            'temperature': 0.7,
            'top_k': 50,
            'top_p': 0.9,
            'repetition_penalty': 1.1,
            'max_length': 150,
            'do_sample': True,
            'num_responses': 1
        }

        print("🚀 MAP-NEO Mini Interactive Chat Interface")
        print("=" * 60)
        self.load_model(checkpoint_path)

    def clear_gpu_cache(self):
        """Clear GPU memory cache."""
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        gc.collect()

    def get_memory_usage(self):
        """Get current GPU memory usage."""
        if not torch.cuda.is_available():
            return "CPU only"

        allocated = torch.cuda.memory_allocated(0) / 1024**3
        cached = torch.cuda.memory_reserved(0) / 1024**3
        total = torch.cuda.get_device_properties(0).total_memory / 1024**3
        return f"{allocated:.2f}GB/{total:.2f}GB (cached: {cached:.2f}GB)"

    def load_model(self, checkpoint_path):
        """Load model and tokenizer."""
        print(f"📂 Loading model from {checkpoint_path}...")

        if not os.path.exists(checkpoint_path):
            print(f"❌ Checkpoint not found: {checkpoint_path}")
            return False

        try:
            checkpoint = torch.load(checkpoint_path, map_location=self.device)

            if 'config' in checkpoint:
                self.max_context_length = checkpoint['config'].get('max_seq_len', 16384)

            config = NeoMiniConfig()
            config.max_seq_len = self.max_context_length
            self.model = NeoMini(config)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.eval()
            self.model = self.model.to(self.device)

            tokenizer_path = "data/tokenizer"
            if Path(tokenizer_path).exists():
                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
            else:
                print("Using GPT-2 tokenizer as fallback...")
                self.tokenizer = AutoTokenizer.from_pretrained("gpt2")
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            print("✅ Model loaded successfully!")
            print(f"🧠 Parameters: {self.model.get_num_params():,}")
            print(f"📏 Context window: {self.max_context_length:,} tokens")
            print(f"💾 Memory: {self.get_memory_usage()}")
            return True

        except Exception as e:
            print(f"❌ Error loading model: {e}")
            return False

    def format_conversation_context(self):
        """Format conversation history for model input."""
        if not self.conversation_history:
            return ""

        context = (
            "The following is a conversation between a human and an AI assistant. "
            "The AI assistant is helpful, harmless, and honest.\n\n"
        )

        for exchange in self.conversation_history:
            context += f"Human: {exchange['human']}\n"
            context += f"AI: {exchange['ai']}\n\n"

        return context
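
    # generate_response() below decodes token-by-token: logits for the last
    # position are divided by `temperature`, a repetition penalty rescales the
    # logits of tokens already generated, then top-k and nucleus (top-p)
    # filtering mask the tail of the distribution before sampling (or argmax
    # when do_sample is False). Generation stops at EOS, after max_length new
    # tokens, or heuristically once the reply ends on sentence punctuation.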

    def generate_response(self, user_input, num_responses=None):
        """Generate AI response(s) to user input."""
        if num_responses is None:
            num_responses = self.params['num_responses']

        context = self.format_conversation_context()
        full_prompt = context + f"Human: {user_input}\nAI: "

        input_ids = self.tokenizer.encode(full_prompt, return_tensors="pt").to(self.device)
        prompt_length = input_ids.size(1)

        print(f"📊 Context: {prompt_length:,}/{self.max_context_length:,} tokens")

        if prompt_length >= self.max_context_length:
            print("⚠️ Context too long, trimming conversation history...")
            self.trim_conversation_history()
            context = self.format_conversation_context()
            full_prompt = context + f"Human: {user_input}\nAI: "
            input_ids = self.tokenizer.encode(full_prompt, return_tensors="pt").to(self.device)
            prompt_length = input_ids.size(1)

        responses = []
        for i in range(num_responses):
            print(f"🤖 Generating response {i+1}/{num_responses}...")

            with torch.no_grad():
                generated = input_ids.clone()
                max_new_tokens = min(self.params['max_length'], self.max_context_length - prompt_length)

                for step in range(max_new_tokens):
                    logits = self.model(generated)
                    next_token_logits = logits[0, -1, :] / self.params['temperature']

                    # Repetition penalty: down-weight tokens that already appear in the sequence.
                    if self.params['repetition_penalty'] != 1.0:
                        for token_id in set(generated[0].tolist()):
                            if next_token_logits[token_id] < 0:
                                next_token_logits[token_id] *= self.params['repetition_penalty']
                            else:
                                next_token_logits[token_id] /= self.params['repetition_penalty']

                    # Top-k filtering: keep only the k highest-scoring tokens.
                    if self.params['top_k'] > 0:
                        top_k_logits, _ = torch.topk(next_token_logits, self.params['top_k'])
                        min_top_k = top_k_logits[-1]
                        next_token_logits[next_token_logits < min_top_k] = float("-inf")

                    # Nucleus (top-p) filtering: keep the smallest set of tokens
                    # whose cumulative probability exceeds top_p.
                    if self.params['top_p'] < 1.0:
                        sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                        cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                        sorted_indices_to_remove = cumulative_probs > self.params['top_p']
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0
                        indices_to_remove = sorted_indices[sorted_indices_to_remove]
                        next_token_logits[indices_to_remove] = float("-inf")

                    if self.params['do_sample']:
                        probs = torch.softmax(next_token_logits, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)
                    else:
                        next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

                    generated = torch.cat([generated, next_token.unsqueeze(0)], dim=1)

                    if next_token.item() == self.tokenizer.eos_token_id:
                        break

                    # Early stop once the reply ends on sentence-final punctuation.
                    if step > 10:
                        decoded = self.tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True)
                        if decoded.strip().endswith(('.', '!', '?', '\n\n')):
                            break

            # Decode only the newly generated tokens; slicing the decoded string
            # by len(full_prompt) can misalign after tokenizer round-tripping.
            ai_response = self.tokenizer.decode(generated[0][prompt_length:], skip_special_tokens=True).strip()

            # Cut off anything the model continues as the next human turn.
            if '\nHuman:' in ai_response:
                ai_response = ai_response.split('\nHuman:')[0].strip()

            responses.append(ai_response)

        return responses
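
    # trim_conversation_history() drops the oldest turns until the formatted
    # context encodes to fewer than half of the context window, leaving
    # headroom for the new prompt and the generated reply.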

    def trim_conversation_history(self):
        """Remove oldest conversation turns to fit context."""
        while len(self.conversation_history) > 1:
            self.conversation_history.pop(0)
            context = self.format_conversation_context()
            if len(self.tokenizer.encode(context)) < self.max_context_length // 2:
                break
        print(f"🧹 Trimmed conversation history to {len(self.conversation_history)} turns")

    def update_parameters(self):
        """Interactive parameter adjustment."""
        print("\n🎛️ Current Generation Parameters:")
        for key, value in self.params.items():
            print(f"  {key}: {value}")

        print("\nEnter new values (press Enter to keep current):")

        temp_input = input(f"Temperature (0.1-2.0, current: {self.params['temperature']}): ").strip()
        if temp_input:
            try:
                self.params['temperature'] = max(0.1, min(2.0, float(temp_input)))
            except ValueError:
                print("❌ Invalid temperature, keeping current value")

        topk_input = input(f"Top-k (0-100, current: {self.params['top_k']}): ").strip()
        if topk_input:
            try:
                self.params['top_k'] = max(0, min(100, int(topk_input)))
            except ValueError:
                print("❌ Invalid top-k, keeping current value")

        topp_input = input(f"Top-p (0.1-1.0, current: {self.params['top_p']}): ").strip()
        if topp_input:
            try:
                self.params['top_p'] = max(0.1, min(1.0, float(topp_input)))
            except ValueError:
                print("❌ Invalid top-p, keeping current value")

        maxlen_input = input(f"Max length (10-500, current: {self.params['max_length']}): ").strip()
        if maxlen_input:
            try:
                self.params['max_length'] = max(10, min(500, int(maxlen_input)))
            except ValueError:
                print("❌ Invalid max length, keeping current value")

        num_resp = input(f"Number of responses (1-3, current: {self.params['num_responses']}): ").strip()
        if num_resp:
            try:
                self.params['num_responses'] = max(1, min(3, int(num_resp)))
            except ValueError:
                print("❌ Invalid number, keeping current value")

        print("✅ Parameters updated!")
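
    # Conversations are written as JSON with roughly this shape (a sketch of
    # what save_conversation() produces and load_conversation() reads back):
    #   {
    #     "timestamp": "YYYYMMDD_HHMMSS",
    #     "model_info": {"max_context": ..., "parameters": ...},
    #     "generation_params": {...},
    #     "conversation": [{"human": ..., "ai": ..., "timestamp": ..., "generation_time": ...}]
    #   }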

    def save_conversation(self):
        """Save conversation to file."""
        if not self.conversation_history:
            print("❌ No conversation to save")
            return

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f"conversation_{timestamp}.json"

        conversation_data = {
            'timestamp': timestamp,
            'model_info': {
                'max_context': self.max_context_length,
                'parameters': self.model.get_num_params()
            },
            'generation_params': self.params,
            'conversation': self.conversation_history
        }

        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(conversation_data, f, indent=2, ensure_ascii=False)

        print(f"💾 Conversation saved to {filename}")

    def load_conversation(self, filename):
        """Load conversation from file."""
        try:
            with open(filename, 'r', encoding='utf-8') as f:
                conversation_data = json.load(f)

            self.conversation_history = conversation_data['conversation']
            print(f"📂 Loaded conversation with {len(self.conversation_history)} turns")

        except Exception as e:
            print(f"❌ Error loading conversation: {e}")

    def show_help(self):
        """Show available commands."""
        print("\n🔧 Available Commands:")
        print("  /help        - Show this help message")
        print("  /params      - Adjust generation parameters")
        print("  /clear       - Clear conversation history")
        print("  /save        - Save current conversation")
        print("  /load <file> - Load conversation from file")
        print("  /memory      - Show GPU memory usage")
        print("  /context     - Show current context usage")
        print("  /multi <n>   - Generate n responses to next input")
        print("  /exit        - Exit the chat")
        print("  /quit        - Exit the chat")

    def run(self):
        """Main chat loop."""
        if not self.model or not self.tokenizer:
            print("❌ Model not loaded. Exiting.")
            return

        print(f"\n💬 Chat started! Context window: {self.max_context_length:,} tokens")
        print("Type /help for commands, /exit to quit")
        print("-" * 60)

        while True:
            try:
                user_input = input("\n👤 You: ").strip()

                if not user_input:
                    continue

                # Command handling; take filenames from user_input (not the
                # lowercased command) so case is preserved.
                if user_input.startswith('/'):
                    command = user_input.lower()

                    if command in ['/exit', '/quit']:
                        print("👋 Goodbye!")
                        break
                    elif command == '/help':
                        self.show_help()
                    elif command == '/params':
                        self.update_parameters()
                    elif command == '/clear':
                        self.conversation_history = []
                        self.clear_gpu_cache()
                        print("🧹 Conversation history cleared")
                    elif command == '/save':
                        self.save_conversation()
                    elif command.startswith('/load '):
                        filename = user_input[6:].strip()
                        self.load_conversation(filename)
                    elif command == '/memory':
                        print(f"💾 GPU Memory: {self.get_memory_usage()}")
                    elif command == '/context':
                        context_length = len(self.tokenizer.encode(self.format_conversation_context()))
                        print(f"📊 Current context: {context_length:,}/{self.max_context_length:,} tokens")
                    elif command.startswith('/multi '):
                        try:
                            num = int(command[7:].strip())
                            self.params['num_responses'] = max(1, min(3, num))
                            print(f"🎯 Next response will generate {self.params['num_responses']} options")
                        except ValueError:
                            print("❌ Invalid number format")
                    else:
                        print("❌ Unknown command. Type /help for available commands.")
                    continue

                start_time = time.time()
                responses = self.generate_response(user_input)
                generation_time = time.time() - start_time

                if len(responses) == 1:
                    print(f"\n🤖 AI: {responses[0]}")
                    chosen_response = responses[0]
                else:
                    print(f"\n🤖 AI generated {len(responses)} responses:")
                    for i, response in enumerate(responses, 1):
                        print(f"\n[{i}] {response}")

                    while True:
                        choice = input(f"\nChoose response (1-{len(responses)}, Enter for 1): ").strip()
                        if not choice:
                            choice = "1"
                        try:
                            choice_idx = int(choice) - 1
                            if 0 <= choice_idx < len(responses):
                                chosen_response = responses[choice_idx]
                                break
                            else:
                                print(f"❌ Invalid choice. Enter 1-{len(responses)}")
                        except ValueError:
                            print("❌ Invalid input. Enter a number.")

                self.conversation_history.append({
                    'human': user_input,
                    'ai': chosen_response,
                    'timestamp': datetime.now().isoformat(),
                    'generation_time': round(generation_time, 2)
                })

                # /multi only applies to the next exchange; reset afterwards.
                if self.params['num_responses'] != 1:
                    self.params['num_responses'] = 1

                print(f"⏱️ Generated in {generation_time:.2f}s | 💾 {self.get_memory_usage()}")

            except KeyboardInterrupt:
                print("\n\n👋 Chat interrupted. Goodbye!")
                break
            except Exception as e:
                print(f"\n❌ Error: {e}")
                self.clear_gpu_cache()


def main():
    import sys

    checkpoint_path = "checkpoints/extended_context_model.pt"
    if len(sys.argv) > 1:
        checkpoint_path = sys.argv[1]

    chat = InteractiveChat(checkpoint_path)
    chat.run()


if __name__ == "__main__":
    main()