| | import torch |
| | from torch import nn |
| | from tqdm import tqdm |
| | from torch.nn import functional as F |
| | from transformers import ( |
| | set_seed, pipeline, AutoTokenizer, AutoModelForCausalLM |
| | ) |
| |
|
# System prompt for the embedding pass; callers apply ``.strip()`` before
# placing it in the chat template.
EMBEDDING = """
You are a helpful AI assistant. Your task is to analyze input text and create a high-quality semantic vector embedding, which represents key concepts, relationships, and semantic meaning.
"""

# System prompt for the reasoning/enrichment pass. It instructs the model to
# finish every response with the literal ``<|embed_token|>`` marker, which is
# the position later read out as the embedding.
GENERATION = """
You are a helpful AI assistant. Your task is to enrich user input for more effective embedding representation by adding semantic depth.

For each input, briefly enhance the content by:
1. Identifying core concepts and their relationships.
2. Including key terminology with essential definitions.
3. Adding contextually relevant synonyms and related terms.
4. Connecting to related topics and common applications without excessive elaboration.

To represent the final embedding, you MUST end every response with <|embed_token|>.
"""
| |
|
| |
|
class SearchR3(nn.Module):
    """Causal-LM wrapper that can both enrich text and embed it.

    ``generate`` asks the model to expand each input with the GENERATION
    system prompt; ``encode`` runs a forward pass over a conversation whose
    assistant turn contains ``<|embed_token|>`` and returns the L2-normalised
    last-layer hidden state at that token.

    Args:
        path: Checkpoint identifier passed to ``from_pretrained`` for both
            the model and the tokenizer.
        max_length: Token budget. ``generate`` truncates prompts to
            ``max_length // 2`` and caps new tokens so that prompt plus
            completion stay within ``max_length``; ``encode`` truncates to
            ``max_length``.
        batch_size: Largest number of items sent through the model at once;
            longer batches are processed in chunks of this size.
    """

    def __init__(self,
                 path: str,
                 max_length: int,
                 batch_size: int):
        nn.Module.__init__(self)

        self.model = AutoModelForCausalLM.from_pretrained(
            path, torch_dtype='auto', device_map='auto'
        )
        # Left truncation/padding keeps the end of every sequence (where the
        # embed token lives) aligned at the right edge of the batch.
        self.tokenizer = AutoTokenizer.from_pretrained(
            path, truncation_side='left', padding_side='left'
        )
        # add_special_tokens=False so a tokenizer that prepends BOS (or other
        # special tokens) cannot push the embed token away from index 0.
        self.embed_token = self.tokenizer.encode(
            '<|embed_token|>', add_special_tokens=False
        )[0]
        self.max_length = max_length
        self.batch_size = batch_size

    @property
    def device(self):
        """Device of the first model parameter."""
        return next(self.model.parameters()).device

    @torch.no_grad()
    def generate(self, batch: list[str]):
        """Generate an enriched assistant turn for every input string.

        Args:
            batch: List or tuple of user texts.

        Returns:
            One conversation per input, in input order:
            ``[system, user, assistant]`` message dicts, where the assistant
            content is the decoded completion with every special token
            except ``<|embed_token|>`` removed and surrounding whitespace
            stripped. An empty batch yields an empty list.

        Raises:
            ValueError: If ``batch`` is not a list/tuple or contains a
                non-string item.
        """
        if not isinstance(batch, (list, tuple)):
            raise ValueError('batch type is incorrect')
        if any(not isinstance(v, str) for v in batch):
            raise ValueError('batch item type is incorrect')
        # Nothing to do; also avoids handing an empty list to the tokenizer.
        if not batch:
            return []

        # Oversized batches are processed chunk by chunk.
        if len(batch) > self.batch_size:
            outputs = []
            for i in tqdm(
                range(0, len(batch), self.batch_size)
            ):
                outputs.extend(
                    self.generate(
                        batch[i:i + self.batch_size]
                    )
                )
            return outputs

        messages = [
            [
                {'role': 'system', 'content': GENERATION.strip()},
                {'role': 'user', 'content': item}
            ]
            for item in batch
        ]
        context = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        # The prompt gets at most half of the budget so that room is left
        # for the completion.
        inputs = self.tokenizer(
            context, padding='longest', truncation=True,
            return_tensors='pt', max_length=self.max_length // 2
        )
        prompt_length = inputs['input_ids'].size(-1)

        self.model.eval()
        outputs = self.model.generate(
            **inputs.to(device=self.device),
            max_new_tokens=self.max_length - prompt_length
        )
        # Drop the (padded) prompt and keep special tokens for now: the
        # embed token must survive decoding.
        outputs = self.tokenizer.batch_decode(
            outputs[:, prompt_length:], skip_special_tokens=False
        )

        # Manually remove every special token except the embed marker.
        for special_token in self.tokenizer.all_special_tokens:
            if special_token == '<|embed_token|>':
                continue
            outputs = [
                item.replace(special_token, '') for item in outputs
            ]
        return [
            conversation + [
                {'role': 'assistant', 'content': completion.strip()}
            ]
            for conversation, completion in zip(messages, outputs)
        ]

    def format(self, batch: list[str]):
        """Wrap raw strings in the fixed embedding conversation.

        Args:
            batch: List or tuple of texts to embed.

        Returns:
            One ``[system, user, assistant]`` conversation per text, using
            the EMBEDDING system prompt and a canned assistant turn ending
            in ``<|embed_token|>``.

        Raises:
            RuntimeError: If ``batch`` is not a list/tuple (a bare string
                would otherwise be split into characters) or contains a
                non-string item.
        """
        if not isinstance(batch, (list, tuple)):
            raise RuntimeError('batch type is incorrect')
        if any(not isinstance(v, str) for v in batch):
            raise RuntimeError('batch type is incorrect')
        return [
            [
                {'role': 'system', 'content': EMBEDDING.strip()},
                {'role': 'user', 'content': item},
                {'role': 'assistant', 'content': 'The embedding is: <|embed_token|>'}
            ]
            for item in batch
        ]

    @torch.no_grad()
    def encode(self, batch: 'list[str] | list[list[dict]]'):
        """Embed raw strings or prepared conversations.

        Args:
            batch: Either a list/tuple of strings (wrapped via ``format``)
                or a list/tuple of conversations whose last two messages
                are a user turn followed by an assistant turn, e.g. the
                output of ``generate``. A conversation whose assistant turn
                lacks ``<|embed_token|>`` is replaced by the formatted
                version of its user text.

        Returns:
            A CPU float tensor with one L2-normalised row per input, in
            input order. An empty batch yields a ``(0, hidden_size)``
            tensor.

        Raises:
            ValueError: If ``batch`` is not a list/tuple.
            RuntimeError: If a conversation does not end with
                user -> assistant, or if an input does not carry exactly
                one embed token in the trailing window read by this method.
        """
        if not isinstance(batch, (list, tuple)):
            raise ValueError('batch type is incorrect')
        # Avoid handing an empty list to the tokenizer/model.
        if not batch:
            return torch.empty(0, self.model.config.hidden_size)

        # Oversized batches are processed chunk by chunk.
        if len(batch) > self.batch_size:
            outputs = [
                self.encode(
                    batch[i:i + self.batch_size]
                )
                for i in tqdm(
                    range(0, len(batch), self.batch_size)
                )
            ]
            return torch.cat(outputs, dim=0)

        if all(isinstance(v, str) for v in batch):
            batch = self.format(batch=batch)

        if any(
            m[-1]['role'] != 'assistant' for m in batch
        ):
            raise RuntimeError('unexpected role')
        if any(
            m[-2]['role'] != 'user' for m in batch
        ):
            raise RuntimeError('unexpected role')

        # Fall back to the canned embedding prompt when the assistant turn
        # never produced the embed marker.
        batch = [
            m if '<|embed_token|>' in m[-1]['content']
            else self.format([m[-2]['content']])[0]
            for m in batch
        ]
        if any(
            '<|embed_token|>' not in m[-1]['content'] for m in batch
        ):
            raise RuntimeError('unexpected embed token')

        context = self.tokenizer.apply_chat_template(
            batch, tokenize=False, add_generation_prompt=False
        )
        inputs = self.tokenizer(
            context, padding='longest', truncation=True,
            return_tensors='pt', max_length=self.max_length
        )

        self.model.eval()
        outputs = self.model(
            **inputs.to(device=self.device),
            return_dict=True, output_hidden_states=True
        )
        hidden_state = outputs['hidden_states'][-1]

        # Only embed tokens within the last four positions count.
        # NOTE(review): presumably this window covers the marker plus the
        # chat template's end-of-turn tokens - confirm against the template.
        length = inputs['input_ids'].size(-1)
        positions = torch.arange(length, device=self.device)
        valid_mask = positions.unsqueeze(0) > length - 5
        embed_mask = inputs['input_ids'] == self.embed_token
        embed_mask = embed_mask.logical_and(valid_mask)

        # Boolean-mask indexing flattens the batch dimension, so every input
        # must contribute exactly one position or the returned rows would no
        # longer line up with the inputs.
        if not bool((embed_mask.sum(dim=-1) == 1).all()):
            raise RuntimeError('unexpected embed token')
        return F.normalize(
            hidden_state[embed_mask].cpu().float(), dim=-1
        )
| |
|
| |
|
def main():
    """Demo: chat with the checkpoint, then compare a query to two documents.

    Prints the raw pipeline response, the generated reasoning conversation,
    and the pairwise Euclidean distances between the query embedding and the
    document embeddings.
    """
    from pprint import pprint

    set_seed(42)

    # Plain chat through the stock text-generation pipeline.
    chat = pipeline(
        task='text-generation',
        model='ytgui/Search-R3.0-Small',
        torch_dtype='auto', device_map='auto'
    )
    conversation = [{"role": 'user', 'content': 'Who are you?'}]
    pprint(chat(conversation, max_new_tokens=256))

    # Reasoning step: let the model enrich the query before embedding it.
    searcher = SearchR3(
        'ytgui/Search-R3.0-Small', max_length=1024, batch_size=8
    )
    query = 'what python library is useful for data analysis?'
    reasoning = searcher.generate(batch=[query])
    pprint(reasoning)

    # Embedding step: documents first, then the enriched query.
    documents = [
        'pandas is a fast, powerful, flexible and easy to use open source data analysis and manipulation tool, built on top of the Python programming language.',
        'The giant panda (Ailuropoda melanoleuca), also known as the panda bear or simply panda, is a bear species endemic to China. It is characterised by its white coat with black patches around the eyes, ears, legs and shoulders.',
    ]
    doc_embeddings = searcher.encode(batch=documents)
    query_embeddings = searcher.encode(batch=reasoning)
    print(
        'distance:',
        torch.cdist(query_embeddings, doc_embeddings, p=2.0)
    )
| |
|
| |
|
# Run the demo only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
| |
|