| """ |
| Data Preprocessing for Memory Routing Training |
| |
| This script converts synthetic JSONL conversations to Tinker-compatible |
| types.Datum objects for supervised fine-tuning. |
| |
| Per Tinker docs (rendering.mdx): |
| - Use renderer.build_supervised_example() to get tokens and weights |
| - Weights indicate which tokens to train on (1.0 for completion, 0.0 for prompt) |
| - Target tokens are shifted by 1 (predicting next token) |
| |
| Per PRD Section 6.6: |
| - Validate datum length <= 4096 |
| - Ensure non-zero weights |
| - Verify token IDs are within vocab range |
| """ |
|
|
| import json |
| import os |
| from typing import List, Dict, Any, Tuple |
| from dataclasses import dataclass |
|
|
| |
| |
| |
|
|
# Base model whose tokenizer is loaded, and the chat renderer paired with it.
MODEL_NAME = "meta-llama/Llama-3.1-8B"
RENDERER_NAME = "llama3"

# Examples whose token count exceeds this limit are skipped.
MAX_SEQUENCE_LENGTH = 4096

# The complete routing taxonomy. A label set is accepted only if every entry
# appears here.
VALID_CATEGORIES = {
    # Company-scoped memory
    "company.brand_core",
    "company.strategic_signatures",
    "company.knowledge_artifacts",
    "company.business_priorities",
    "company.tools_config",
    "company.performance_context",
    # User-scoped memory
    "user.communication_style",
    "user.strategic_approach",
    "user.role_context",
    "user.workflow_patterns",
    "user.session_history",
    "user.interaction_preferences",
    # Explicit "nothing applies" label
    "none",
}
|
|
@dataclass
class PreprocessingStats:
    """
    Counters describing the outcome of one preprocessing run.

    Every counter starts at zero. The preprocessing functions increment
    exactly one of the ``skipped_*`` counters when they drop an example,
    or ``valid_examples`` when they keep it.
    """

    # Number of raw examples loaded from the input file.
    total_examples: int = 0
    # Examples that passed every check and were kept.
    valid_examples: int = 0
    # Examples dropped for exceeding the sequence length limit.
    skipped_too_long: int = 0
    # Examples dropped because no target token carried a non-zero loss weight.
    skipped_zero_weights: int = 0
    # Examples dropped because a token ID fell outside the vocabulary range.
    skipped_invalid_tokens: int = 0
    # Examples dropped because a label was not in the category taxonomy.
    skipped_invalid_categories: int = 0
|
|
|
|
def build_routing_prompt(conversation: List[Dict[str, str]], categories: List[str]) -> List[Dict[str, str]]:
    """
    Build the three-message chat used as one supervised training example:
    1. System prompt with taxonomy
    2. User message with the flattened conversation
    3. Assistant response with the comma-separated categories

    Per PRD Section 6 - Student Prompt format.

    Args:
        conversation: Turns to flatten into a transcript. A dict turn is
            rendered as ``ROLE: content``; a bare string turn is rendered as
            ``UNKNOWN: <string>``; any other turn type is ignored. A missing
            or null ``role`` is rendered as ``UNKNOWN`` and a missing or null
            ``content`` as an empty string.
        categories: Labels joined with ", " to form the assistant reply.

    Returns:
        ``[system, user, assistant]`` messages, each a dict with ``role``
        and ``content`` keys.
    """
    system_content = """You route marketing conversations into structured memory categories.

Available categories:
- company.brand_core: Voice, values, positioning, identity anchors (Long >1y)
- company.strategic_signatures: Decision frameworks, strategic heuristics (Long >1y)
- company.knowledge_artifacts: Docs, style guides, playbooks (Long >1y)
- company.business_priorities: Quarterly/seasonal goals, active campaigns (Short <3m)
- company.tools_config: Integrations, API keys, workflow settings (Medium ~6m)
- company.performance_context: Campaign metrics, retrospectives, learnings (Rolling ~6m)
- user.communication_style: Tone, verbosity, format expectations (Long >1y)
- user.strategic_approach: Personal priorities, success definitions (Long >1y)
- user.role_context: Title, scope, decision authority (Medium ~1y)
- user.workflow_patterns: Review cadence, collaboration norms (Medium ~1y)
- user.session_history: Immediate context, recent asks (Short <2w)
- user.interaction_preferences: Coaching style, feedback expectations (Evolving)
- none: Irrelevant, vague, or transactional content

Respond with comma-separated categories. Use 'none' only if no other category applies."""

    # Collect one transcript line per usable turn and join once at the end
    # instead of growing a string with repeated concatenation.
    transcript_lines = []
    for turn in conversation:
        if isinstance(turn, str):
            transcript_lines.append(f"UNKNOWN: {turn}")
            continue
        if not isinstance(turn, dict):
            continue

        # Treat an explicit JSON null exactly like a missing key, so a
        # record with `"role": null` does not crash on `.upper()` and a
        # null content is not rendered as the literal text "None".
        role = turn.get("role")
        if role is None:
            role = "unknown"
        content = turn.get("content")
        if content is None:
            content = ""

        # str() guards against non-string roles (e.g. numbers) in the data.
        transcript_lines.append(f"{str(role).upper()}: {content}")

    conversation_text = "\n".join(transcript_lines).strip()
    user_content = f"Conversation:\n{conversation_text}\n\nWhat memory categories apply?"

    assistant_content = ", ".join(categories)

    return [
        {"role": "system", "content": system_content},
        {"role": "user", "content": user_content},
        {"role": "assistant", "content": assistant_content},
    ]
|
|
|
|
def load_synthetic_data(filepath: str) -> List[Dict[str, Any]]:
    """
    Load synthetic data from a JSONL file (one JSON value per line).

    Blank and whitespace-only lines are skipped. The file is read as UTF-8
    so the result does not depend on the platform's default encoding.

    Args:
        filepath: Path to the JSONL file.

    Returns:
        The parsed values, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON. The
            message is prefixed with ``<filepath>:<line number>`` so the
            offending record can be located.
    """
    data = []
    with open(filepath, "r", encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Re-raise the same exception type with the file position
                # added; the parser alone only reports an offset in the line.
                raise json.JSONDecodeError(
                    f"{filepath}:{line_number}: {err.msg}", err.doc, err.pos
                ) from err
    return data
|
|
|
|
def validate_categories(categories: List[str]) -> bool:
    """
    Validate that `categories` is a non-empty collection of taxonomy labels.

    Args:
        categories: Labels to check.

    Returns:
        True only if `categories` is a non-empty list/tuple/set whose every
        entry is a string present in VALID_CATEGORIES. Anything else
        (None, a bare string, an empty collection, non-string entries)
        yields False rather than raising.
    """
    # A bare string would be iterated character by character, and None or a
    # number is not iterable at all; report both as invalid.
    if not isinstance(categories, (list, tuple, set, frozenset)):
        return False

    # An empty label set would produce an empty assistant completion, which
    # gives the model nothing to learn; "none" is the explicit no-match label.
    if not categories:
        return False

    # The isinstance check runs first so unhashable entries (e.g. nested
    # lists) are rejected instead of raising TypeError on set membership.
    return all(isinstance(cat, str) and cat in VALID_CATEGORIES for cat in categories)
|
|
|
|
def preprocess_example_mock(example: Dict[str, Any], stats: PreprocessingStats) -> Dict[str, Any] | None:
    """
    Mock preprocessing that validates structure without Tinker.
    Returns a dict representation of what would become a Datum.

    Use this for testing without Tinker installed.

    Args:
        example: One raw record; reads the ``conversation``,
            ``labels.categories`` and ``scenario_id`` keys.
        stats: Counters updated in place: exactly one of
            ``skipped_invalid_categories``, ``skipped_too_long`` or
            ``valid_examples`` is incremented per call.

    Returns:
        A dict with ``messages``, ``categories``, ``estimated_tokens`` and
        ``scenario_id`` keys, or None when the example is skipped.
    """
    # A JSON null is treated like a missing field so that one malformed
    # record is handled here instead of aborting the whole run.
    conversation = example.get("conversation")
    if conversation is None:
        conversation = []

    labels = example.get("labels")
    if labels is None:
        labels = {}
    if not isinstance(labels, dict):
        stats.skipped_invalid_categories += 1
        return None

    categories = labels.get("categories")
    if categories is None:
        categories = []

    # Check the container type before validating so a non-list value is
    # counted as invalid rather than raising inside the validator.
    if not isinstance(categories, list) or not validate_categories(categories):
        stats.skipped_invalid_categories += 1
        return None

    training_messages = build_routing_prompt(conversation, categories)

    # No tokenizer is available here, so approximate the token count as
    # one token per four characters of message content.
    total_chars = sum(len(m["content"]) for m in training_messages)
    estimated_tokens = total_chars // 4

    if estimated_tokens > MAX_SEQUENCE_LENGTH:
        stats.skipped_too_long += 1
        return None

    stats.valid_examples += 1

    return {
        "messages": training_messages,
        "categories": categories,
        "estimated_tokens": estimated_tokens,
        "scenario_id": example.get("scenario_id", "unknown"),
    }
|
|
|
|
def preprocess_with_tinker(example: Dict[str, Any], renderer, tokenizer, vocab_size: int, stats: PreprocessingStats):
    """
    Full preprocessing with Tinker renderer.

    Per Tinker docs (rendering.mdx):
    - build_supervised_example returns (tokens, weights)
    - weights=1.0 for completion tokens, weights=0.0 for prompt tokens

    Per Tinker docs (training-sampling.mdx):
    - input_tokens = tokens[:-1]
    - target_tokens = tokens[1:]  # Shifted for next-token prediction
    - weights = weights[1:]

    Args:
        example: One raw record; reads the ``conversation`` and
            ``labels.categories`` keys.
        renderer: Object providing ``build_supervised_example(messages)``.
        tokenizer: Not used by this function; kept so the signature stays
            stable for callers.
        vocab_size: Exclusive upper bound for valid token IDs.
        stats: Counters updated in place: exactly one ``skipped_*`` counter
            or ``valid_examples`` is incremented per call.

    Returns:
        A ``tinker.types.Datum``, or None when the example is skipped.
    """
    from tinker import types

    # A JSON null is treated like a missing field so that one malformed
    # record is handled here instead of aborting the whole run.
    conversation = example.get("conversation")
    if conversation is None:
        conversation = []

    labels = example.get("labels")
    if labels is None:
        labels = {}
    if not isinstance(labels, dict):
        stats.skipped_invalid_categories += 1
        return None

    categories = labels.get("categories")
    if categories is None:
        categories = []

    # Check the container type before validating so a non-list value is
    # counted as invalid rather than raising inside the validator.
    if not isinstance(categories, list) or not validate_categories(categories):
        stats.skipped_invalid_categories += 1
        return None

    training_messages = build_routing_prompt(conversation, categories)

    tokens, weights = renderer.build_supervised_example(training_messages)

    # NOTE(review): depending on the renderer version these may be array-like
    # (tensor/ndarray) rather than plain lists - confirm. Normalising to
    # Python lists keeps the checks below and ModelInput.from_ints working on
    # plain ints/floats; values that are already lists are left untouched.
    if hasattr(tokens, "tolist"):
        tokens = tokens.tolist()
    if hasattr(weights, "tolist"):
        weights = weights.tolist()

    if len(tokens) > MAX_SEQUENCE_LENGTH:
        stats.skipped_too_long += 1
        return None

    # Shift by one position for next-token prediction.
    input_tokens = tokens[:-1]
    target_tokens = tokens[1:]
    loss_weights = weights[1:]

    # With every weight at zero the example contributes nothing to the loss.
    if not any(loss_weights):
        stats.skipped_zero_weights += 1
        return None

    # Validate the whole sequence: checking only the shifted targets would
    # leave the first input token (tokens[0]) unverified.
    if not all(0 <= t < vocab_size for t in tokens):
        stats.skipped_invalid_tokens += 1
        return None

    datum = types.Datum(
        model_input=types.ModelInput.from_ints(input_tokens),
        loss_fn_inputs=dict(
            target_tokens=target_tokens,
            weights=loss_weights,
        ),
    )

    stats.valid_examples += 1
    return datum
|
|
|
|
def preprocess_dataset(
    input_path: str,
    output_dir: str,
    use_tinker: bool = False,
    train_split: float = 0.8
) -> Tuple[PreprocessingStats, str, str]:
    """
    Preprocess the full dataset and write train/test JSON files.

    Examples are processed in file order and split without shuffling: the
    first ``train_split`` fraction of the kept examples becomes the training
    set and the remainder the test set.

    Args:
        input_path: Path to training_dataset_1000.jsonl
        output_dir: Directory to save processed data (created if missing)
        use_tinker: Whether to use actual Tinker (requires installation)
        train_split: Fraction for training (rest is test); must lie in [0, 1]

    Returns:
        stats, train_path, test_path

    Raises:
        ValueError: If ``train_split`` is outside [0, 1].
    """
    # Fail before touching the filesystem: an out-of-range fraction would
    # otherwise silently produce a nonsensical split.
    if not 0.0 <= train_split <= 1.0:
        raise ValueError(f"train_split must be between 0 and 1, got {train_split!r}")

    os.makedirs(output_dir, exist_ok=True)

    print(f"Loading data from {input_path}...")
    raw_data = load_synthetic_data(input_path)
    print(f"Loaded {len(raw_data)} examples")

    stats = PreprocessingStats(total_examples=len(raw_data))

    # Pick the per-example processor once so a single loop serves both modes.
    if use_tinker:
        # Imported lazily so the mock path works without Tinker installed.
        from tinker_cookbook import renderers, tokenizer_utils

        print(f"Initializing tokenizer for {MODEL_NAME}...")
        tokenizer = tokenizer_utils.get_tokenizer(MODEL_NAME)
        renderer = renderers.get_renderer(name=RENDERER_NAME, tokenizer=tokenizer)
        vocab_size = len(tokenizer)
        print(f"Vocab size: {vocab_size}")

        def process(example):
            return preprocess_with_tinker(example, renderer, tokenizer, vocab_size, stats)
    else:
        print("Running mock preprocessing (no Tinker)...")

        def process(example):
            return preprocess_example_mock(example, stats)

    processed_data = []
    for i, example in enumerate(raw_data):
        if i % 100 == 0:
            print(f"Processing {i}/{len(raw_data)}...")
        item = process(example)
        if item is not None:
            processed_data.append(item)

    split_idx = int(len(processed_data) * train_split)
    train_data = processed_data[:split_idx]
    test_data = processed_data[split_idx:]

    train_path = os.path.join(output_dir, "train_data.json")
    test_path = os.path.join(output_dir, "test_data.json")

    def write_split(path, items):
        # Mock results are already plain dicts; anything else is serialized
        # through its model_dump() method.
        payload = [d if isinstance(d, dict) else d.model_dump() for d in items]
        with open(path, "w", encoding="utf-8") as f:
            json.dump(payload, f)

    write_split(train_path, train_data)
    write_split(test_path, test_data)

    print("\n=== Preprocessing Complete ===")
    print(f"Total examples: {stats.total_examples}")
    print(f"Valid examples: {stats.valid_examples}")
    print(f"Skipped (too long): {stats.skipped_too_long}")
    print(f"Skipped (zero weights): {stats.skipped_zero_weights}")
    print(f"Skipped (invalid tokens): {stats.skipped_invalid_tokens}")
    print(f"Skipped (invalid categories): {stats.skipped_invalid_categories}")
    print(f"\nTrain set: {len(train_data)} examples")
    print(f"Test set: {len(test_data)} examples")
    print("\nSaved to:")
    print(f"  Train: {train_path}")
    print(f"  Test: {test_path}")

    return stats, train_path, test_path
|
|
|
|
if __name__ == "__main__":
    import sys

    # Usage: script.py [input_path] [output_dir] [--tinker]
    #
    # Separate flags from positional arguments first. Reading sys.argv[1]
    # and sys.argv[2] directly would treat `--tinker` as a path whenever it
    # is given before (or instead of) the positional arguments.
    cli_args = sys.argv[1:]
    positional = [arg for arg in cli_args if not arg.startswith("--")]

    input_path = positional[0] if positional else "synthetic_data/training_dataset_1000.jsonl"
    output_dir = positional[1] if len(positional) > 1 else "training/processed_data"
    use_tinker = "--tinker" in cli_args

    preprocess_dataset(input_path, output_dir, use_tinker=use_tinker)
|
|
|
|