Source code for eli5.lime.textutils

# -*- coding: utf-8 -*-
"""
Utilities for text generation.
"""
from __future__ import absolute_import
import re
import math
from typing import List, Tuple, Union, Optional

import numpy as np
from sklearn.utils import check_random_state

from eli5.utils import indices_to_bool_mask, vstack


# the same as scikit-learn's default token pattern, but allows single-char tokens
DEFAULT_TOKEN_PATTERN = r'(?u)\b\w+\b'

# non-whitespace chars
CHAR_TOKEN_PATTERN = r'[^\s]'
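
# Illustrative sanity check (not part of the original module): the default
# pattern keeps single-character tokens such as "a", which scikit-learn's
# default pattern r'(?u)\b\w\w+\b' would drop, while CHAR_TOKEN_PATTERN
# treats every non-whitespace character as a separate token.
#
#     >>> re.findall(DEFAULT_TOKEN_PATTERN, 'a cat')
#     ['a', 'cat']
#     >>> re.findall(CHAR_TOKEN_PATTERN, 'a cat')
#     ['a', 'c', 'a', 't']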


def generate_samples(text,  # type: TokenizedText
                     n_samples=500,  # type: int
                     bow=True,  # type: bool
                     random_state=None,
                     replacement='',  # type: str
                     min_replace=1,  # type: Union[int, float]
                     max_replace=1.0,  # type: Union[int, float]
                     group_size=1,  # type: int
                     ):
    # type: (...) -> Tuple[List[str], np.ndarray, np.ndarray]
    """
    Return ``n_samples`` changed versions of text (with some words removed),
    along with distances between the original text and the generated examples.
    If ``bow=False``, all tokens are considered unique
    (i.e. token position matters).
    """
    kwargs = dict(
        n_samples=n_samples,
        replacement=replacement,
        random_state=random_state,
        min_replace=min_replace,
        max_replace=max_replace,
    )
    if bow:
        num_tokens = len(text.vocab)
        res = text.replace_random_tokens_bow(**kwargs)
    else:
        num_tokens = len(text.tokens)
        res = text.replace_random_tokens(group_size=group_size, **kwargs)

    texts, num_removed_vec, masks = zip(*res)
    similarity = cosine_similarity_vec(num_tokens, num_removed_vec)
    return texts, similarity, vstack(masks)
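
# Usage sketch (an illustration, not part of the original module): sample
# perturbed texts for LIME-style local approximation. The input string and
# the seed below are made up for the example.
#
#     >>> t = TokenizedText('hello world hello')
#     >>> texts, similarity, masks = generate_samples(t, n_samples=3,
#     ...                                             random_state=42)
#
# ``texts`` holds 3 perturbed strings, ``similarity`` their cosine
# similarity to the original text, and ``masks`` a boolean matrix marking
# which token positions were replaced in each sample.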


def cosine_similarity_vec(num_tokens, num_removed_vec):
    """
    Return cosine similarity between a binary vector with all ones
    of length ``num_tokens`` and vectors of the same length with
    ``num_removed_vec`` elements set to zero.
    """
    remaining = -np.array(num_removed_vec) + num_tokens
    return remaining / (np.sqrt(num_tokens + 1e-6) * np.sqrt(remaining + 1e-6))
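
# Worked example (illustrative): with ``num_tokens=4`` and one token removed,
# the original vector is [1, 1, 1, 1] and a perturbed one is e.g.
# [1, 0, 1, 1]; their dot product equals the number of remaining tokens (3)
# and the norms are sqrt(4) and sqrt(3), giving 3 / (sqrt(4) * sqrt(3)),
# i.e. about 0.866. The 1e-6 terms only guard against division by zero when
# every token is removed.
#
#     >>> cosine_similarity_vec(4, [1])
#     array([0.866...])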


class TokenizedText(object):
    def __init__(self, text, token_pattern=DEFAULT_TOKEN_PATTERN):
        # type: (str, str) -> None
        self.text = text
        self.split = SplitResult.fromtext(text, token_pattern)
        self._vocab = None  # type: Optional[List[str]]

    def replace_random_tokens(self,
                              n_samples,  # type: int
                              replacement='',  # type: str
                              random_state=None,
                              min_replace=1,  # type: Union[int, float]
                              max_replace=1.0,  # type: Union[int, float]
                              group_size=1  # type: int
                              ):
        # type: (...) -> List[Tuple[str, int, np.ndarray]]
        """
        Return a list of ``(text, replaced_count, mask)`` tuples with
        n_samples versions of text with some words replaced.
        By default words are replaced with '', i.e. removed.
        """
        n_tokens = len(self.tokens)
        indices = np.arange(n_tokens)
        if not n_tokens:
            nomask = np.array([], dtype=int)
            return [('', 0, nomask)] * n_samples
        min_replace, max_replace = self._get_min_max(min_replace, max_replace,
                                                     n_tokens)
        rng = check_random_state(random_state)
        replace_sizes = rng.randint(low=min_replace, high=max_replace + 1,
                                    size=n_samples)
        res = []
        for num_to_replace in replace_sizes:
            idx_to_replace = rng.choice(indices, num_to_replace,
                                        replace=False)
            idx_to_replace = np.array([idx_to_replace] + [
                idx_to_replace + shift for shift in range(1, group_size)
            ]).ravel()
            padded_size = n_tokens + group_size - 1
            mask = indices_to_bool_mask(idx_to_replace, padded_size)[:n_tokens]
            s = self.split.masked(mask, replacement)
            res.append((s.text, num_to_replace, mask))
        return res

    def replace_random_tokens_bow(self,
                                  n_samples,  # type: int
                                  replacement='',  # type: str
                                  random_state=None,
                                  min_replace=1,  # type: Union[int, float]
                                  max_replace=1.0,  # type: Union[int, float]
                                  ):
        # type: (...) -> List[Tuple[str, int, np.ndarray]]
        """
        Return a list of ``(text, replaced_words_count, mask)`` tuples with
        n_samples versions of text with some words replaced.
        If a word is replaced, all of its duplicates are also replaced
        in the text. By default words are replaced with '', i.e. removed.
        """
        if not self.vocab:
            nomask = np.array([], dtype=int)
            return [('', 0, nomask)] * n_samples
        min_replace, max_replace = self._get_min_max(min_replace, max_replace,
                                                     len(self.vocab))
        rng = check_random_state(random_state)
        replace_sizes = rng.randint(low=min_replace, high=max_replace + 1,
                                    size=n_samples)
        res = []
        for num_to_replace in replace_sizes:
            tokens_to_replace = set(rng.choice(self.vocab, num_to_replace,
                                               replace=False))
            idx_to_replace = [idx for idx, token in enumerate(self.tokens)
                              if token in tokens_to_replace]
            mask = indices_to_bool_mask(idx_to_replace, len(self.tokens))
            s = self.split.masked(idx_to_replace, replacement)
            res.append((s.text, num_to_replace, mask))
        return res

    def _get_min_max(self,
                     min_replace,  # type: Union[int, float]
                     max_replace,  # type: Union[int, float]
                     hard_maximum  # type: int
                     ):
        # type: (...) -> Tuple[int, int]
        if isinstance(min_replace, float):
            min_replace = int(math.floor(hard_maximum * min_replace)) or 1
        if isinstance(max_replace, float):
            max_replace = int(math.ceil(hard_maximum * max_replace))
        else:
            max_replace = min(max_replace, hard_maximum)
        return min_replace, max_replace

    @property
    def vocab(self):
        # type: () -> List[str]
        if self._vocab is None:
            self._vocab = sorted(set(self.tokens))
        return self._vocab

    @property
    def tokens(self):
        return self.split.tokens

    @property
    def spans_and_tokens(self):
        return list(zip(self.split.token_spans, self.split.tokens))


class SplitResult(object):
    def __init__(self, parts):
        self.parts = np.array(parts, ndmin=1)
        self.lengths = np.array([len(p) for p in parts])
        self.starts = self.lengths.cumsum()

    @classmethod
    def fromtext(cls, text, token_pattern=DEFAULT_TOKEN_PATTERN):
        # type: (str, str) -> SplitResult
        # Wrap the pattern in a capturing group so that re.split keeps
        # both the separators (even indices) and the tokens (odd indices).
        token_pattern = u"(%s)" % token_pattern
        parts = re.split(token_pattern, text)
        return cls(parts)

    @property
    def separators(self):
        return self.parts[::2]

    @property
    def tokens(self):
        return self.parts[1::2]

    @property
    def token_spans(self):
        # type: () -> List[Tuple[int, int]]
        return list(zip(self.starts[::2], self.starts[1::2]))

    def copy(self):
        # type: () -> SplitResult
        return self.__class__(self.parts.copy())

    def masked(self, invmask, replacement=''):
        # type: (Union[np.ndarray, List[int]], str) -> SplitResult
        s = self.copy()
        s.tokens[invmask] = replacement
        return s

    @property
    def text(self):
        # type: () -> str
        return "".join(self.parts)
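
# Round-trip sketch (illustrative, not part of the original module):
# ``SplitResult`` stores separators at even indices of ``parts`` and tokens
# at odd indices, so replacing tokens preserves the surrounding whitespace
# and punctuation exactly.
#
#     >>> s = SplitResult.fromtext('Hello, world!')
#     >>> list(s.tokens)
#     ['Hello', 'world']
#     >>> s.masked([1]).text
#     'Hello, !'
#
# Note that ``parts`` is a fixed-width numpy string array, so a
# ``replacement`` longer than the longest original part would be truncated;
# the default empty-string replacement is unaffected.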