import json
import logging
import os

from transformers import PreTrainedTokenizer
|
|
class CharacterTokenizer(PreTrainedTokenizer):
    """Character-level tokenizer for OCR tasks, built on ``PreTrainedTokenizer``.

    Each character of the input text becomes one token. Decoding joins the
    characters back together with no separator, so the result is continuous
    text.

    The vocabulary is a JSON object mapping token strings to integer ids and
    is loaded from ``vocab_file``. Two lookup tables are kept on the instance:
    ``token_to_id`` (the loaded mapping) and ``id_to_token`` (its inverse).
    """

    def __init__(
        self,
        vocab_file=None,
        unk_token="<unk>",
        pad_token="<pad>",
        bos_token="<s>",
        eos_token="</s>",
        max_length=256,
        **kwargs
    ):
        """Load the vocabulary and initialise the base tokenizer.

        Args:
            vocab_file: Path to a JSON file mapping token -> id.
            unk_token: Unknown-token string, forwarded to the base class.
            pad_token: Padding-token string, forwarded to the base class.
            bos_token: Beginning-of-sequence token, forwarded to the base class.
            eos_token: End-of-sequence token, forwarded to the base class.
            max_length: Stored on the instance as ``self.max_length``; this
                class does not use it anywhere else.
            **kwargs: Forwarded unchanged to the base class constructor.

        Raises:
            ValueError: If ``vocab_file`` is ``None`` or is not an existing file.
        """
        if vocab_file is None or not os.path.isfile(vocab_file):
            raise ValueError("`vocab_file` must be provided or exist.")

        # NOTE(review): the lookup tables are populated *before* calling the
        # base constructor -- presumably because the base class may query the
        # vocabulary while it initialises. Keep this ordering.
        with open(vocab_file, "r", encoding="utf-8") as f:
            self.token_to_id = json.load(f)
        self.id_to_token = {v: k for k, v in self.token_to_id.items()}

        self.max_length = max_length

        super().__init__(
            unk_token=unk_token,
            pad_token=pad_token,
            bos_token=bos_token,
            eos_token=eos_token,
            **kwargs
        )

    @classmethod
    def register_for_auto_class(cls, auto_class="AutoTokenizer"):
        """Return the class unchanged.

        This override performs no registration of its own and ignores
        ``auto_class``.
        NOTE(review): confirm that bypassing the base-class implementation is
        intentional.
        """
        return cls

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a tokenizer from a local directory or a Hub repository.

        ``vocab.json`` is required. ``tokenizer_config.json`` is optional and
        loaded on a best-effort basis: if it is missing or unreadable the
        tokenizer is built from ``kwargs`` and the constructor defaults alone.

        Args:
            pretrained_model_name_or_path: An existing directory containing
                ``vocab.json``; anything else is treated as a Hub ``repo_id``
                and the files are fetched with ``hf_hub_download``.
            **kwargs: Constructor arguments. Values passed here take
                precedence over values stored in ``tokenizer_config.json``.
                Any ``vocab_file`` entry is discarded in favour of the
                resolved ``vocab.json`` path.

        Returns:
            A new instance of ``cls``.
        """
        log = logging.getLogger(__name__)

        if os.path.isdir(pretrained_model_name_or_path):
            vocab_file = os.path.join(pretrained_model_name_or_path, "vocab.json")
            config_file = os.path.join(
                pretrained_model_name_or_path, "tokenizer_config.json"
            )
        else:
            # Imported lazily so loading from a local directory does not
            # require huggingface_hub to be installed.
            from huggingface_hub import hf_hub_download

            vocab_file = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                filename="vocab.json"
            )
            try:
                config_file = hf_hub_download(
                    repo_id=pretrained_model_name_or_path,
                    filename="tokenizer_config.json"
                )
            except Exception as err:
                # The config is optional, so a failed download is not fatal.
                # ``Exception`` (not a bare except) lets KeyboardInterrupt and
                # SystemExit propagate.
                log.debug(
                    "No tokenizer_config.json for %s (%s); using defaults.",
                    pretrained_model_name_or_path,
                    err,
                )
                config_file = None

        config = cls._read_tokenizer_config(config_file)

        # Explicit caller arguments override the saved configuration.
        init_kwargs = {**config, **kwargs}
        # The vocabulary path is always the one resolved above.
        init_kwargs.pop("vocab_file", None)

        return cls(vocab_file=vocab_file, **init_kwargs)

    @staticmethod
    def _read_tokenizer_config(config_file):
        """Return the JSON object stored in ``config_file``, or ``{}`` if unusable."""
        if config_file is None or not os.path.exists(config_file):
            return {}

        try:
            # save_vocabulary() writes this file as UTF-8; read it the same way.
            with open(config_file, "r", encoding="utf-8") as f:
                config = json.load(f)
        except (OSError, ValueError) as err:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logging.getLogger(__name__).warning(
                "Ignoring unreadable tokenizer config %s: %s", config_file, err
            )
            return {}

        if not isinstance(config, dict):
            logging.getLogger(__name__).warning(
                "Ignoring tokenizer config %s: expected a JSON object.", config_file
            )
            return {}
        return config

    @property
    def vocab_size(self):
        """Number of entries in the loaded vocabulary."""
        return len(self.token_to_id)

    def get_vocab(self):
        """Return a copy of the token -> id mapping.

        A copy is returned so that callers cannot mutate ``token_to_id`` and
        leave it out of sync with ``id_to_token``.
        """
        return dict(self.token_to_id)

    def _tokenize(self, text):
        """Split ``text`` into a list of single characters."""
        return list(text)

    def _convert_token_to_id(self, token):
        """Return the id of ``token``, or ``self.unk_token_id`` if it is not in the vocabulary."""
        return self.token_to_id.get(token, self.unk_token_id)

    def _convert_id_to_token(self, index):
        """Return the token for ``index``, or ``self.unk_token`` if the id is unknown."""
        return self.id_to_token.get(index, self.unk_token)

    def _special_token_set(self):
        """Return the set {pad, bos, eos, unk} of special-token strings."""
        return {self.pad_token, self.bos_token, self.eos_token, self.unk_token}

    def convert_tokens_to_string(self, tokens):
        """Join ``tokens`` into one string with no separator.

        The pad, bos, eos and unk tokens are always dropped here, regardless
        of any ``skip_special_tokens`` setting used by the caller.
        """
        # Built once rather than once per token.
        specials = self._special_token_set()
        return "".join(token for token in tokens if token not in specials)

    def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True, **kwargs):
        """Convert a sequence of ids into continuous text.

        Args:
            token_ids: A single int id, an iterable of ids, or any object
                exposing ``tolist()`` that yields one of those.
            skip_special_tokens: If true, pad/bos/eos/unk tokens are removed
                before joining. Because ``convert_tokens_to_string`` strips
                the same tokens unconditionally, they are absent from the
                result even when this flag is false.
            clean_up_tokenization_spaces: Accepted for interface
                compatibility; not used.
            **kwargs: Accepted for interface compatibility; not used.

        Returns:
            The decoded string. Ids missing from the vocabulary map to the
            unk token and are therefore dropped.
        """
        if hasattr(token_ids, "tolist"):
            token_ids = token_ids.tolist()
        # A bare id (or a 0-d array converted above) is a one-element sequence.
        if isinstance(token_ids, int):
            token_ids = [token_ids]

        tokens = [self._convert_id_to_token(token_id) for token_id in token_ids]

        if skip_special_tokens:
            specials = self._special_token_set()
            tokens = [token for token in tokens if token not in specials]

        return self.convert_tokens_to_string(tokens)

    def batch_decode(self, sequences, skip_special_tokens=False, clean_up_tokenization_spaces=True, **kwargs):
        """Decode every sequence in ``sequences`` with ``decode``; return a list of strings."""
        return [
            self.decode(
                seq,
                skip_special_tokens=skip_special_tokens,
                clean_up_tokenization_spaces=clean_up_tokenization_spaces,
                **kwargs
            )
            for seq in sequences
        ]

    def save_vocabulary(self, save_directory, filename_prefix=None):
        """Write ``vocab.json`` and ``tokenizer_config.json`` into ``save_directory``.

        The directory is created if needed. Both files are written as UTF-8
        and an existing ``tokenizer_config.json`` is overwritten.

        Args:
            save_directory: Target directory.
            filename_prefix: Accepted for interface compatibility; currently
                ignored, so the file names are always fixed.

        Returns:
            A one-element tuple containing the path of the written ``vocab.json``.
        """
        os.makedirs(save_directory, exist_ok=True)

        vocab_path = os.path.join(save_directory, "vocab.json")
        with open(vocab_path, "w", encoding="utf-8") as f:
            json.dump(self.token_to_id, f, ensure_ascii=False, indent=2)

        config_path = os.path.join(save_directory, "tokenizer_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump({
                "tokenizer_class": "CharacterTokenizer",
                "auto_map": {
                    "AutoTokenizer": ["tokenizer.CharacterTokenizer", None]
                },
                "bos_token": self.bos_token,
                "eos_token": self.eos_token,
                "unk_token": self.unk_token,
                "pad_token": self.pad_token,
                "vocab_file": "vocab.json",
                "clean_up_tokenization_spaces": False,
            }, f, indent=2)

        return (vocab_path,)

    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """Wrap id lists with bos/eos ids.

        Single sequence: ``[bos] + ids_0 + [eos]``.
        Pair: ``[bos] + ids_0 + [eos] + ids_1 + [eos]``.
        """
        if token_ids_1 is None:
            return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
        return (
            [self.bos_token_id]
            + token_ids_0
            + [self.eos_token_id]
            + token_ids_1
            + [self.eos_token_id]
        )

    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """Return all-zero token type ids, one per position of the built input.

        The length matches ``build_inputs_with_special_tokens`` for the same
        arguments; both sequences of a pair receive type id 0.
        """
        return [0] * len(
            self.build_inputs_with_special_tokens(token_ids_0, token_ids_1)
        )
|
|