地平線まで行ってくる。

記録あるいは忘備録。時には検討事項。

日本語ModernBERTのFT版、AMBERをLangchainでembeddingとしてColabで使ってみる。

SB Insituionsさんの日本語ModernBERTをFinetuningしたtext embedding modelである、AMBERをlanchainで利用してみます。Retrievalのベンチマーク成績も高く、今後利用してみたいので、基本的な部分を勉強のため実装してみます。日本語の環境がだんだん整っていくのは嬉しいですね。

 

huggingface.co

 

LangchainのEmbeddingsのbaseとして、構築します。もちろん、LLMにも聞きながらの実装です。ドキュメント、クエリーのエンコード用のプロンプトをそれぞれ組み込みます。

 

from typing import List, Dict, Any, Optional
from langchain_core.embeddings import Embeddings
from sentence_transformers import SentenceTransformer

class AMBEREmbeddings(Embeddings):
    """
    LangChain用のAMBER埋め込みモデルラッパークラス
    """
    
    def __init__(
        self,
        model_name: str = "retrieva-jp/amber-large",
        query_prompt_name: str = "Retrieval-query",
        document_prompt_name: str = "Retrieval-passage",
        cache_folder: Optional[str] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
        encode_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        初期化メソッド
        
        Args:
            model_name: Hugging Faceモデル名
            query_prompt_name: クエリエンコード時のプロンプト名
            document_prompt_name: ドキュメントエンコード時のプロンプト名
            cache_folder: モデルキャッシュディレクト
            model_kwargs: SentenceTransformerモデル初期化用の追加引数
            encode_kwargs: encode関数用の追加引数
        """
        self.model_name = model_name
        self.query_prompt_name = query_prompt_name
        self.document_prompt_name = document_prompt_name
        self.cache_folder = cache_folder
        self.model_kwargs = model_kwargs or {}
        self.encode_kwargs = encode_kwargs or {}
        
        # SentenceTransformerモデルの初期化
        self._model = SentenceTransformer(
            model_name_or_path=model_name,
            cache_folder=cache_folder,
            **self.model_kwargs
        )
    
    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """
        ドキュメントテキストをベクトル埋め込みに変換
        
        Args:
            texts: 埋め込むドキュメントのリスト
            
        Returns:
            埋め込みベクトルのリスト
        """
        # ドキュメント用プロンプトを使用してエンコード
        encode_kwargs = {**self.encode_kwargs, "prompt_name": self.document_prompt_name}
        embeddings = self._model.encode(texts, **encode_kwargs)
        return embeddings.tolist()
    
    def embed_query(self, text: str) -> List[float]:
        """
        クエリテキストをベクトル埋め込みに変換
        
        Args:
            text: 埋め込むクエリテキスト
            
        Returns:
            クエリの埋め込みベクトル
        """
        # クエリ用プロンプトを使用してエンコード
        encode_kwargs = {**self.encode_kwargs, "prompt_name": self.query_prompt_name}
        embedding = self._model.encode(text, **encode_kwargs)
        return embedding.tolist()

 

テストに使ったColab:

gist.github.com