llamaindex update + unpacking archives in data
This commit is contained in:
71
services/rag/llamaindex/helpers/embedding.py
Normal file
71
services/rag/llamaindex/helpers/embedding.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
from llama_index.core.embeddings import BaseEmbedding
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
class OpenAICompatibleEmbedding(BaseEmbedding):
    """Embedding model backed by an OpenAI-compatible ``/embeddings`` endpoint.

    Every request is a JSON ``POST`` to ``{api_base}/embeddings`` carrying the
    model name and the input texts, authenticated with a bearer token.
    """

    # Model identifier sent in the "model" field of the request payload.
    model: str = Field(...)
    # Token sent as "Authorization: Bearer <api_key>".
    api_key: str = Field(...)
    # Base URL of the API; stored without trailing slashes.
    api_base: str = Field(...)
    # Timeout passed to ``requests.post`` for each call.
    timeout: int = Field(default=60)

    def __init__(
        self,
        model: str,
        api_key: str,
        api_base: str,
        timeout: int = 60,
        **kwargs,
    ) -> None:
        """Create the embedding client.

        Args:
            model: Model identifier to request embeddings from.
            api_key: Bearer token for the ``Authorization`` header.
            api_base: Base URL of the API; trailing slashes are stripped.
            timeout: Timeout for each HTTP request.
            **kwargs: Extra keyword arguments forwarded unchanged to the
                base class initializer.
        """
        # The attributes above are declared as pydantic fields, so they must
        # be populated through the base initializer. Assigning them on
        # ``self`` before it has run fails because the model's internal
        # state does not exist yet.
        super().__init__(
            model=model,
            api_key=api_key,
            api_base=api_base.rstrip("/"),
            timeout=timeout,
            **kwargs,
        )

    # ---------- low-level call ----------

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """Request embeddings for ``texts`` in a single HTTP call.

        Args:
            texts: Strings to embed.

        Returns:
            One embedding vector per input text, in input order.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        # Nothing to embed: skip the round trip instead of posting an empty
        # input list to the server.
        if not texts:
            return []

        response = requests.post(
            f"{self.api_base}/embeddings",
            headers={
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": self.model,
                "input": texts,
            },
            timeout=self.timeout,
        )
        response.raise_for_status()
        items = response.json()["data"]

        # Items may carry an "index" naming the input they belong to. The
        # sort is stable, so a response without indices keeps its order.
        ordered = sorted(items, key=lambda item: item.get("index", 0))
        return [item["embedding"] for item in ordered]

    async def _run_blocking(self, func, *args):
        """Run the blocking callable ``func(*args)`` in the default executor.

        The HTTP call is synchronous; awaiting it through the executor keeps
        the event loop free while the request is in flight.
        """
        import asyncio

        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, func, *args)

    # ---------- document embeddings ----------

    def _get_text_embedding(self, text: str) -> List[float]:
        """Return the embedding of a single document text."""
        return self._embed([text])[0]

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Return the embeddings of several document texts in one request."""
        return self._embed(texts)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Async variant of ``_get_text_embedding``."""
        return await self._run_blocking(self._get_text_embedding, text)

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Async variant of ``_get_text_embeddings``."""
        return await self._run_blocking(self._get_text_embeddings, texts)

    # ---------- query embeddings (REQUIRED) ----------

    def _get_query_embedding(self, query: str) -> List[float]:
        """Return the embedding of a query string."""
        # bge-m3 uses same embedding for query & doc
        return self._embed([query])[0]

    async def _aget_query_embedding(self, query: str) -> List[float]:
        """Async variant of ``_get_query_embedding``."""
        return await self._run_blocking(self._get_query_embedding, query)
|
||||
Reference in New Issue
Block a user