• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

deepset-ai / haystack / 10093246889

25 Jul 2024 11:25AM UTC coverage: 90.145%. Remained the same
10093246889

Pull #8060

github

web-flow
Merge a3643314b into de728b487
Pull Request #8060: Docs: Standardize and improve docstrings

6860 of 7610 relevant lines covered (90.14%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.87
haystack/components/embedders/sentence_transformers_text_embedder.py
1
# SPDX-FileCopyrightText: 2022-present deepset GmbH <info@deepset.ai>
2
#
3
# SPDX-License-Identifier: Apache-2.0
4

5
from typing import Any, Dict, List, Optional
1✔
6

7
from haystack import component, default_from_dict, default_to_dict
1✔
8
from haystack.components.embedders.backends.sentence_transformers_backend import (
1✔
9
    _SentenceTransformersEmbeddingBackendFactory,
10
)
11
from haystack.utils import ComponentDevice, Secret, deserialize_secrets_inplace
1✔
12

13

14
@component
class SentenceTransformersTextEmbedder:
    """
    Embeds strings using Sentence Transformers models.

    You can use it to embed a user query and send it to an embedding retriever.

    Usage example:
    ```python
    from haystack.components.embedders import SentenceTransformersTextEmbedder

    text_to_embed = "I love pizza!"

    text_embedder = SentenceTransformersTextEmbedder()
    text_embedder.warm_up()

    print(text_embedder.run(text_to_embed))

    # {'embedding': [-0.07804739475250244, 0.1498992145061493, ...]}
    ```
    """

    def __init__(
        self,
        model: str = "sentence-transformers/all-mpnet-base-v2",
        device: Optional[ComponentDevice] = None,
        token: Optional[Secret] = Secret.from_env_var(["HF_API_TOKEN", "HF_TOKEN"], strict=False),
        prefix: str = "",
        suffix: str = "",
        batch_size: int = 32,
        progress_bar: bool = True,
        normalize_embeddings: bool = False,
        trust_remote_code: bool = False,
    ):
        """
        Create a SentenceTransformersTextEmbedder component.

        The model itself is not loaded here; call `warm_up()` before `run()`.

        :param model:
            The model to use for calculating embeddings.
            Specify the path to a local model or the ID of the model on Hugging Face.
        :param device:
            Overrides the default device used to load the model.
        :param token:
            An API token to use private models from Hugging Face.
        :param prefix:
            A string to add at the beginning of each text to be embedded.
            You can use it to prepend the text with an instruction, as required by some embedding models,
            such as E5 and bge.
        :param suffix:
            A string to add at the end of each text to embed.
        :param batch_size:
            Number of texts to embed at once.
        :param progress_bar:
            If `True`, shows a progress bar for calculating embeddings.
            If `False`, disables the progress bar.
        :param normalize_embeddings:
            If `True`, returned vectors have a length of 1.
        :param trust_remote_code:
            If `False`, permits only Hugging Face verified model architectures.
            If `True`, permits custom models and scripts.
        """

        self.model = model
        # Resolve `None` (or an explicit device) into a concrete ComponentDevice up front,
        # so `to_dict()` and `warm_up()` can rely on `self.device` being set.
        self.device = ComponentDevice.resolve_device(device)
        self.token = token
        self.prefix = prefix
        self.suffix = suffix
        self.batch_size = batch_size
        self.progress_bar = progress_bar
        self.normalize_embeddings = normalize_embeddings
        self.trust_remote_code = trust_remote_code

    def _get_telemetry_data(self) -> Dict[str, Any]:
        """
        Data that is sent to Posthog for usage analytics.

        :returns:
            Dictionary containing only the model name.
        """
        return {"model": self.model}

    def to_dict(self) -> Dict[str, Any]:
        """
        Serializes the component to a dictionary.

        :returns:
            Dictionary with serialized data.
        """
        return default_to_dict(
            self,
            model=self.model,
            device=self.device.to_dict(),
            # A missing token is serialized as None; a Secret is serialized via its own to_dict().
            token=self.token.to_dict() if self.token else None,
            prefix=self.prefix,
            suffix=self.suffix,
            batch_size=self.batch_size,
            progress_bar=self.progress_bar,
            normalize_embeddings=self.normalize_embeddings,
            trust_remote_code=self.trust_remote_code,
        )

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "SentenceTransformersTextEmbedder":
        """
        Deserializes the component from a dictionary.

        Note: the `device` and `token` entries of `data["init_parameters"]` are
        replaced in place with their deserialized objects.

        :param data:
            Dictionary to deserialize from.
        :returns:
            Deserialized component.
        """
        init_params = data["init_parameters"]
        if init_params.get("device") is not None:
            init_params["device"] = ComponentDevice.from_dict(init_params["device"])
        deserialize_secrets_inplace(init_params, keys=["token"])
        return default_from_dict(cls, data)

    def warm_up(self):
        """
        Initializes the component by creating the embedding backend for the configured model.

        Calling this more than once is a no-op: the backend is only created when the
        `embedding_backend` attribute does not exist yet.
        """
        if not hasattr(self, "embedding_backend"):
            self.embedding_backend = _SentenceTransformersEmbeddingBackendFactory.get_embedding_backend(
                model=self.model,
                device=self.device.to_torch_str(),
                auth_token=self.token,
                trust_remote_code=self.trust_remote_code,
            )

    @component.output_types(embedding=List[float])
    def run(self, text: str):
        """
        Embed a single string.

        :param text:
            Text to embed.

        :returns:
            A dictionary with the following keys:
            - `embedding`: The embedding of the input text.

        :raises TypeError:
            If `text` is not a string (e.g. a list of Documents was passed).
        :raises RuntimeError:
            If `warm_up()` has not been called before running.
        """
        if not isinstance(text, str):
            # Adjacent literals are concatenated, so the trailing space after "input." is required.
            raise TypeError(
                "SentenceTransformersTextEmbedder expects a string as input. "
                "In case you want to embed a list of Documents, please use the SentenceTransformersDocumentEmbedder."
            )
        if not hasattr(self, "embedding_backend"):
            raise RuntimeError("The embedding model has not been loaded. Please call warm_up() before running.")

        # Wrap the text with the configured prefix/suffix (e.g. model-specific instructions).
        text_to_embed = self.prefix + text + self.suffix
        # The backend embeds a batch of texts; pass a one-element batch and take its single result.
        embedding = self.embedding_backend.embed(
            [text_to_embed],
            batch_size=self.batch_size,
            show_progress_bar=self.progress_bar,
            normalize_embeddings=self.normalize_embeddings,
        )[0]
        return {"embedding": embedding}
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc