• Home
  • Features
  • Pricing
  • Docs
  • Announcements
  • Sign In

digiteinfotech / kairon / 12112350604

02 Dec 2024 03:52AM UTC coverage: 89.891% (-0.04%) from 89.932%
12112350604

Pull #1611

github

web-flow
Merge 9176d03d1 into f2f296b80
Pull Request #1611: Mail channel implementation

383 of 434 new or added lines in 15 files covered. (88.25%)

12 existing lines in 2 files now uncovered.

24141 of 26856 relevant lines covered (89.89%)

0.9 hits per line

Source File
Press 'n' to go to next uncovered line, 'b' for previous

94.92
/kairon/shared/llm/processor.py
1
import time
1✔
2
import urllib.parse
1✔
3
from secrets import randbelow, choice
1✔
4
from typing import Text, Dict, List, Tuple, Union
1✔
5
from urllib.parse import urljoin
1✔
6

7
import litellm
1✔
8
from loguru import logger as logging
1✔
9
from mongoengine.base import BaseList
1✔
10
from tiktoken import get_encoding
1✔
11
from tqdm import tqdm
1✔
12

13
from kairon.exceptions import AppException
1✔
14
from kairon.shared.actions.utils import ActionUtility
1✔
15
from kairon.shared.admin.data_objects import LLMSecret
1✔
16
from kairon.shared.admin.processor import Sysadmin
1✔
17
from kairon.shared.cognition.data_objects import CognitionData
1✔
18
from kairon.shared.cognition.processor import CognitionDataProcessor
1✔
19
from kairon.shared.data.constant import DEFAULT_LLM
1✔
20
from kairon.shared.data.constant import DEFAULT_SYSTEM_PROMPT, DEFAULT_CONTEXT_PROMPT
1✔
21
from kairon.shared.llm.base import LLMBase
1✔
22
from kairon.shared.llm.data_objects import LLMLogs
1✔
23
from kairon.shared.llm.logger import LiteLLMLogger
1✔
24
from kairon.shared.models import CognitionDataType
1✔
25
from kairon.shared.rest_client import AioRestClient
1✔
26
from kairon.shared.utils import Utility
1✔
27
from http import HTTPStatus
1✔
28

29
# Route every litellm request/response through kairon's custom logger so all
# LLM traffic is traced centrally.
litellm.callbacks = [LiteLLMLogger()]
30

31

32
class LLMProcessor(LLMBase):
    """LLM-backed FAQ processor: trains vector collections from cognition data
    and answers queries via retrieval-augmented chat completion."""

    # Dimensionality of the embedding vectors stored in the vector DB.
    __embedding__ = 1536

    def __init__(self, bot: Text, llm_type: str):
        """Set up vector-store connection details and LLM secrets for a bot.

        :param bot: bot id; used as the prefix for collection names.
        :param llm_type: which configured LLM to use for completions.
        """
        super().__init__(bot)
        self.db_url = Utility.environment['vector']['db']
        self.headers = {}
        if Utility.environment['vector']['key']:
            self.headers = {"api-key": Utility.environment['vector']['key']}
        # All FAQ collections for this bot end with this suffix.
        self.suffix = "_faq_embd"
        self.llm_type = llm_type
        self.vector_config = {'size': self.__embedding__, 'distance': 'Cosine'}
        self.llm_secret = Sysadmin.get_llm_secret(llm_type, bot)

        # Embeddings are always produced with the default LLM's credentials,
        # even when completions use a different provider.
        if llm_type != DEFAULT_LLM:
            self.llm_secret_embedding = Sysadmin.get_llm_secret(DEFAULT_LLM, bot)
        else:
            self.llm_secret_embedding = self.llm_secret

        self.tokenizer = get_encoding("cl100k_base")
        # Token limit of the openai embedding model; inputs are truncated to this.
        self.EMBEDDING_CTX_LENGTH = 8191
        # Per-instance log of answers, rephrases and errors (see `logs` property).
        self.__logs = []
54

55
    async def train(self, user, *args, **kwargs) -> Dict:
        """Rebuild every FAQ vector collection for this bot from CognitionData.

        Drops the bot's existing collections, groups stored rows by collection
        name, embeds them in batches of 100 and upserts the resulting points.

        :param user: user id recorded against the embedding calls.
        :return: dict with the number of trained rows, e.g. ``{"faq": 42}``.
        """
        invocation = kwargs.pop('invocation', None)
        await self.__delete_collections()
        trained_rows = 0
        processor = CognitionDataProcessor()
        batch_size = 100

        # Bucket all stored cognition rows by their collection name ("" = default).
        grouped = {}
        for record in CognitionData.objects(bot=self.bot):
            as_dict = record.to_mongo()
            grouped.setdefault(as_dict.get('collection') or "", []).append(as_dict)

        for group_name, rows in grouped.items():
            collection = f"{self.bot}_{group_name}{self.suffix}" if group_name else f"{self.bot}{self.suffix}"
            await self.__create_collection__(collection)

            for start in tqdm(range(0, len(rows), batch_size), desc="Training FAQ"):
                batch = rows[start:start + batch_size]

                embedding_payloads = []
                search_payloads = []
                vector_ids = []
                for row in batch:
                    # JSON rows get schema-aware payloads; plain text is stored as-is.
                    if row['content_type'] == CognitionDataType.json.value:
                        metadata = processor.find_matching_metadata(self.bot, row['data'],
                                                                    row.get('collection'))
                        search_payload, embedding_payload = Utility.retrieve_search_payload_and_embedding_payload(
                            row['data'], metadata)
                    else:
                        search_payload, embedding_payload = {'content': row["data"]}, row["data"]
                    embedding_payloads.append(embedding_payload)
                    search_payloads.append(search_payload)
                    vector_ids.append(row['vector_id'])

                embeddings = await self.get_embedding(embedding_payloads, user, invocation=invocation)

                points = [{'id': vid, 'vector': vec, 'payload': payload}
                          for vid, vec, payload in zip(vector_ids, embeddings, search_payloads)]
                await self.__collection_upsert__(collection, {'points': points},
                                                 err_msg="Unable to train FAQ! Contact support")
                trained_rows += len(batch)

        return {"faq": trained_rows}
104

105
    async def predict(self, query: Text, user, *args, **kwargs) -> Tuple:
        """Answer a user query with retrieval-augmented chat completion.

        Embeds the query, attaches similarity context, then requests an
        answer. Failures are captured and reported in the response rather
        than raised. Expects 'similarity_prompt' and 'hyperparameters' in
        kwargs (missing keys surface via the failure path).

        :param query: the user's question.
        :param user: user id recorded against the LLM calls.
        :return: tuple of (response dict, elapsed seconds); on failure the
            dict carries ``is_failure``/``exception`` and ``content`` is None.
        """
        # perf_counter is monotonic, so the reported duration cannot be
        # distorted by wall-clock adjustments (time.time can jump backwards).
        start_time = time.perf_counter()
        embeddings_created = False
        invocation = kwargs.pop('invocation', None)
        llm_type = kwargs.pop('llm_type', DEFAULT_LLM)
        try:
            query_embedding = await self.get_embedding(query, user, invocation=invocation)
            embeddings_created = True

            system_prompt = kwargs.pop('system_prompt', DEFAULT_SYSTEM_PROMPT)
            context_prompt = kwargs.pop('context_prompt', DEFAULT_CONTEXT_PROMPT)

            context = await self.__attach_similarity_prompt_if_enabled(query_embedding, context_prompt, **kwargs)
            answer = await self.__get_answer(query, system_prompt, context, user, invocation=invocation,
                                             llm_type=llm_type, **kwargs)
            response = {"content": answer}
        except Exception as e:
            logging.exception(e)
            # Record which stage failed so the log entry is actionable.
            if embeddings_created:
                failure_stage = "Retrieving chat completion for the provided query."
            else:
                failure_stage = "Creating a new embedding for the provided query."
            self.__logs.append({'error': f"{failure_stage} {str(e)}"})
            response = {"is_failure": True, "exception": str(e), "content": None}

        elapsed_time = time.perf_counter() - start_time
        return response, elapsed_time
132

133
    def truncate_text(self, texts: List[Text]) -> List[Text]:
1✔
134
        """
135
        Truncate multiple texts to 8191 tokens for openai
136
        """
137
        truncated_texts = []
1✔
138

139
        for text in texts:
1✔
140
            tokens = self.tokenizer.encode(text)[:self.EMBEDDING_CTX_LENGTH]
1✔
141
            truncated_texts.append(self.tokenizer.decode(tokens))
1✔
142

143
        return truncated_texts
1✔
144

145
    async def get_embedding(self, texts: Union[Text, List[Text]], user, **kwargs):
        """
        Create embeddings via litellm for a single text or a batch.

        A single string yields one embedding vector; a list yields a list of
        vectors in input order. Texts are truncated to the model's context
        limit before embedding.
        """
        single_input = isinstance(texts, str)
        batch = [texts] if single_input else texts

        truncated_texts = self.truncate_text(batch)

        result = await litellm.aembedding(
            model="text-embedding-3-small",
            input=truncated_texts,
            metadata={'user': user, 'bot': self.bot, 'invocation': kwargs.get("invocation")},
            api_key=self.llm_secret_embedding.get('api_key'),
            num_retries=3
        )

        embeddings = [item["embedding"] for item in result["data"]]

        return embeddings[0] if single_input else embeddings
169

170
    async def __parse_completion_response(self, response, **kwargs):
1✔
171
        if kwargs.get("stream"):
×
172
            formatted_response = ''
×
173
            msg_choice = randbelow(kwargs.get("n", 1))
×
174
            if response["choices"][0].get("index") == msg_choice and response["choices"][0]['delta'].get('content'):
×
175
                formatted_response = f"{response['choices'][0]['delta']['content']}"
×
176
        else:
177
            msg_choice = choice(response['choices'])
×
178
            formatted_response = msg_choice['message']['content']
×
179
        return formatted_response
×
180

181
    async def __get_completion(self, messages, hyperparameters, user, **kwargs):
        """Call kairon's LLM completion service for the given chat messages.

        Posts the transcript and hyperparameters to the external completion
        endpoint scoped to this bot and llm_type.

        :param messages: chat transcript in OpenAI message format.
        :param hyperparameters: model settings forwarded to the service.
        :param user: user id recorded with the request.
        :raises Exception: with the HTTP reason phrase on a non-2xx status.
        :return: tuple of (formatted_response, raw_response); when the service
            returns a non-dict body, the same object is returned for both.
        """
        body = {
            'messages': messages,
            'hyperparameters': hyperparameters,
            'user': user,
            'invocation': kwargs.get("invocation")
        }

        # Default request timeout is 30s unless configured otherwise.
        timeout = Utility.environment['llm'].get('request_timeout', 30)
        http_response, status_code, elapsed_time, _ = await ActionUtility.execute_request_async(http_url=f"{Utility.environment['llm']['url']}/{urllib.parse.quote(self.bot)}/completion/{self.llm_type}",
                                                                     request_method="POST",
                                                                     request_body=body,
                                                                     timeout=timeout)

        logging.info(f"LLM request completed in {elapsed_time} for bot: {self.bot}")
        # Anything outside the 2xx range is surfaced as a plain Exception
        # carrying the standard reason phrase for the status code.
        if status_code not in [200, 201, 202, 203, 204]:
            raise Exception(HTTPStatus(status_code).phrase)

        if isinstance(http_response, dict):
            return http_response.get("formatted_response"), http_response.get("response")
        else:
            return http_response, http_response
203

204

205
    async def __get_answer(self, query, system_prompt: Text, context: Text, user, **kwargs):
        """Build the chat transcript and fetch a completion for the query.

        Optionally rephrases the query first (when a query prompt is
        enabled), prepends previous bot responses, then requests an answer
        from the completion service. The exchange is appended to the
        in-memory logs.

        :param query: user question, possibly rephrased before completion.
        :param system_prompt: system message for the transcript.
        :param context: retrieved context to ground the answer.
        :param user: user id recorded against the LLM call.
        :return: the formatted completion text.
        """
        use_query_prompt = False
        query_prompt = ''
        invocation = kwargs.pop('invocation')
        llm_type = kwargs.get('llm_type')
        if kwargs.get('query_prompt', {}):
            query_prompt_dict = kwargs.pop('query_prompt')
            query_prompt = query_prompt_dict.get('query_prompt', '')
            use_query_prompt = query_prompt_dict.get('use_query_prompt')
        previous_bot_responses = kwargs.get('previous_bot_responses')
        hyperparameters = kwargs['hyperparameters']
        instructions = kwargs.get('instructions', [])
        instructions = '\n'.join(instructions)

        if use_query_prompt and query_prompt:
            query = await self.__rephrase_query(query, system_prompt, query_prompt,
                                                hyperparameters=hyperparameters,
                                                user=user,
                                                invocation=f"{invocation}_rephrase")
        messages = [
            {"role": "system", "content": system_prompt},
        ]
        if previous_bot_responses:
            messages.extend(previous_bot_responses)
            # Perplexity needs the search-domain filter folded into the query text.
            query = self.modify_user_message_for_perplexity(query, llm_type, hyperparameters)
        # Plain if/else instead of a conditional expression evaluated only for
        # its append side effect — behavior is unchanged, intent is explicit.
        if instructions:
            messages.append({"role": "user", "content": f"{context} \n{instructions} \nQ: {query} \nA:"})
        else:
            messages.append({"role": "user", "content": f"{context} \nQ: {query} \nA:"})
        completion, raw_response = await self.__get_completion(messages=messages,
                                                               hyperparameters=hyperparameters,
                                                               user=user,
                                                               invocation=invocation)
        self.__logs.append({'messages': messages, 'raw_completion_response': raw_response,
                            'type': 'answer_query', 'hyperparameters': hyperparameters})
        return completion
239

240
    async def __rephrase_query(self, query, system_prompt: Text, query_prompt: Text, user, **kwargs):
        """Ask the LLM to rewrite the user query as directed by query_prompt.

        The rephrase exchange is appended to the in-memory logs.

        :return: the rephrased query text.
        """
        invocation = kwargs.pop('invocation')
        hyperparameters = kwargs['hyperparameters']
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": f"{query_prompt}\n\n Q: {query}\n A:"}
        ]

        rephrased_query, raw = await self.__get_completion(messages=messages,
                                                           hyperparameters=hyperparameters,
                                                           user=user,
                                                           invocation=invocation)
        self.__logs.append({'messages': messages, 'raw_completion_response': raw,
                            'type': 'rephrase_query', 'hyperparameters': hyperparameters})
        return rephrased_query
255

256
    async def __delete_collections(self):
        """Delete every vector-store collection whose name starts with this bot id.

        The REST client is always cleaned up, even when listing or deletion fails.
        """
        client = AioRestClient(False)
        try:
            listing = await client.request(http_url=urljoin(self.db_url, "/collections"),
                                           request_method="GET",
                                           headers=self.headers,
                                           timeout=5)
            if listing.get('result'):
                for entry in listing['result'].get('collections') or []:
                    name = entry['name']
                    if not name.startswith(self.bot):
                        continue
                    await client.request(http_url=urljoin(self.db_url, f"/collections/{name}"),
                                         request_method="DELETE",
                                         headers=self.headers,
                                         return_json=False,
                                         timeout=5)
        finally:
            await client.cleanup()
273

274
    async def __create_collection__(self, collection_name: Text):
        """Create a vector-store collection configured with this bot's vector settings."""
        target_url = urljoin(self.db_url, f"/collections/{collection_name}")
        payload = {'name': collection_name, 'vectors': self.vector_config}
        await AioRestClient().request(http_url=target_url,
                                      request_method="PUT",
                                      headers=self.headers,
                                      request_body=payload,
                                      return_json=False,
                                      timeout=5)
281

282
    async def __collection_upsert__(self, collection_name: Text, data: Dict, err_msg: Text, raise_err=True):
        """Upsert points into a collection.

        On a failed upsert the store's error is logged and AppException(err_msg)
        is raised unless raise_err is False. A failure response without a
        'status' field is silently ignored, matching the store's contract.
        """
        response = await AioRestClient().request(
            http_url=urljoin(self.db_url, f"/collections/{collection_name}/points"),
            request_method="PUT",
            headers=self.headers,
            request_body=data,
            return_json=True,
            timeout=5)
        if response.get('result'):
            return
        if "status" in response:
            logging.exception(response['status'].get('error'))
            if raise_err:
                raise AppException(err_msg)
295

296
    async def __collection_exists__(self, collection_name: Text) -> bool:
        """Return True when the vector store reports the collection.

        Best-effort: any failure is logged at info level and treated as
        "does not exist".
        """
        target = urljoin(self.db_url, f"/collections/{collection_name}")
        try:
            # Keep the response access inside the try so a malformed/None
            # response also falls through to False, as before.
            reply = await AioRestClient().request(
                http_url=target,
                request_method="GET",
                headers=self.headers,
                return_json=True,
                timeout=5
            )
            return reply.get('status') == "ok"
        except Exception as e:
            logging.info(e)
            return False
310

311
    async def __collection_search__(self, collection_name: Text, vector: List, limit: int, score_threshold: float):
        """Run a payload-returning similarity search against one collection."""
        search_body = {'vector': vector, 'limit': limit, 'with_payload': True, 'score_threshold': score_threshold}
        client = AioRestClient()
        return await client.request(
            http_url=urljoin(self.db_url, f"/collections/{collection_name}/points/search"),
            request_method="POST",
            headers=self.headers,
            request_body=search_body,
            return_json=True,
            timeout=5)
321

322
    @property
    def logs(self):
        """Accumulated per-call log entries (answers, rephrases, errors) for this instance."""
        return self.__logs
325

326
    async def __attach_similarity_prompt_if_enabled(self, query_embedding, context_prompt, **kwargs):
        """Augment the context prompt with similarity-search results.

        For every enabled similarity prompt, searches its collection with the
        query embedding and appends the retrieved payloads plus the prompt's
        usage instructions to the context. Requires kwargs['similarity_prompt'].

        :return: the (possibly extended) context prompt.
        """
        similarity_prompt = kwargs.pop('similarity_prompt')
        for prompt_cfg in similarity_prompt:
            if not prompt_cfg.get('use_similarity_prompt'):
                continue
            similarity_prompt_name = prompt_cfg.get('similarity_prompt_name')
            similarity_prompt_instructions = prompt_cfg.get('similarity_prompt_instructions')
            limit = prompt_cfg.get('top_results', 10)
            score_threshold = prompt_cfg.get('similarity_threshold', 0.70)

            if prompt_cfg.get('collection') == 'default':
                collection_name = f"{self.bot}{self.suffix}"
            else:
                collection_name = f"{self.bot}_{prompt_cfg.get('collection')}{self.suffix}"
            search_result = await self.__collection_search__(collection_name, vector=query_embedding, limit=limit,
                                                             score_threshold=score_threshold)

            extracted_values = []
            for entry in search_result['result']:
                payload = entry['payload']
                if 'content' in payload:
                    extracted_values.append(payload['content'])
                else:
                    # Structured payloads: keep everything except the internal
                    # collection_name marker.
                    extracted_values.append({k: v for k, v in payload.items() if k != 'collection_name'})
            if extracted_values:
                similarity_context = f"Instructions on how to use {similarity_prompt_name}:\n{extracted_values}\n{similarity_prompt_instructions}\n"
                context_prompt = f"{context_prompt}\n{similarity_context}"
        return context_prompt
356

357
    @staticmethod
    def get_logs(bot: str, start_idx: int = 0, page_size: int = 10):
        """
        Yield LLM call logs for a bot, newest first, one page at a time.
        @param bot: bot id.
        @param start_idx: start index of the page.
        @param page_size: number of logs per page.
        @return: generator of log dicts with the MongoDB `_id` removed.
        """
        for log in LLMLogs.objects(metadata__bot=bot).order_by("-start_time").skip(start_idx).limit(page_size):
            llm_log = log.to_mongo().to_dict()
            llm_log.pop('_id')
            yield llm_log
370

371
    @staticmethod
    def get_row_count(bot: str):
        """
        Gets the count of rows in LLMLogs for a particular bot.
        :param bot: bot id
        :return: Count of rows
        """
        return LLMLogs.objects(metadata__bot=bot).count()
379

380
    @staticmethod
    def fetch_llm_metadata(bot: str):
        """
        Fetches the llm_type and corresponding models for a particular bot.
        :param bot: bot id
        :return: dictionary where each key is a llm_type and the value is the
            metadata schema with the per-bot model list filled into
            properties.model.enum.
        """
        from copy import deepcopy

        # Work on a copy: Utility.llm_metadata is shared module-level state,
        # and writing per-bot model enums into it would leak one bot's models
        # into every subsequent caller.
        metadata = deepcopy(Utility.llm_metadata)

        for llm_type in metadata:
            # Prefer the bot-scoped secret; fall back to the global (bot-less) one.
            secret = LLMSecret.objects(bot=bot, llm_type=llm_type).first()
            if not secret:
                secret = LLMSecret.objects(llm_type=llm_type, bot__exists=False).first()

            if secret:
                models = list(secret.models) if isinstance(secret.models, BaseList) else secret.models
            else:
                models = []

            metadata[llm_type]['properties']['model']['enum'] = models

        return metadata
403

404
    @staticmethod
1✔
405
    def modify_user_message_for_perplexity(user_msg: str, llm_type: str, hyperparameters: Dict) -> str:
1✔
406
        """
407
        Modify the user message if the LLM type is 'perplexity' and a search domain filter is provided.
408
        :param user_msg: The original user message.
409
        :param llm_type: The LLM type to check if it's 'perplexity'.
410
        :param hyperparameters: LLM hyperparameters
411
        :return: Modified user message.
412
        """
413
        if llm_type == 'perplexity':
1✔
UNCOV
414
            search_domain_filter = hyperparameters.get('search_domain_filter')
×
UNCOV
415
            if search_domain_filter:
×
UNCOV
416
                search_domain_filter_str = "|".join(
×
417
                    [domain.strip() for domain in search_domain_filter if domain.strip()]
418
                )
UNCOV
419
                user_msg = f"{user_msg} inurl:{search_domain_filter_str}"
×
420
        return user_msg
1✔
STATUS · Troubleshooting · Open an Issue · Sales · Support · CAREERS · ENTERPRISE · START FREE · SCHEDULE DEMO
ANNOUNCEMENTS · TWITTER · TOS & SLA · Supported CI Services · What's a CI service? · Automated Testing

© 2025 Coveralls, Inc