Hello,
I’m serving a LangChain chain with FastAPI, using the approach here:
def get_chain(
    vectorstore: VectorStore, question_handler, stream_handler, tracing: bool = False
) -> ConversationalRetrievalChain:  # <== CHANGE THE TYPE
    """Create a ChatVectorDBChain for question/answering."""
    # Construct a ChatVectorDBChain with a streaming llm for combine docs
    # and a separate, non-streaming llm for question generation
    manager = AsyncCallbackManager([])
    question_manager = AsyncCallbackManager([question_handler])
    stream_manager = AsyncCallbackManager([stream_handler])
    if tracing:
        tracer = LangChainTracer()
        tracer.load_default_session()
        manager.add_handler(tracer)
        question_manager.add_handler(tracer)
        stream_manager.add_handler(tracer)

    question_gen_llm = OpenAI(
        temperature=0,
        verbose=True,
        callback_manager=question_manager,
    )
    streaming_llm = OpenAI(
        streaming=True,
        callback_manager=stream_manager,
        verbose=True,
        temperature=0,
    )

    question_generator = LLMChain(
        llm=question_gen_llm, prompt=CONDENSE_QUESTION_PROMPT, callback_manager=manager
    )
    doc_chain = load_qa_chain(
        streaming_llm, chain_type="stuff", prompt=QA_PROMPT, callback_manager=manager
    )

    qa = ConversationalRetrievalChain(  # <== CHANGE: ConversationalRetrievalChain instead of ChatVectorDBChain
        # vectorstore=vectorstore,  # <== REMOVE THIS
        retriever=vectorstore.as_retriever(),  # <== ADD THIS
        combine_docs_chain=doc_chain,
        question_generator=question_generator,
        callback_manager=manager,
    )
    return qa
where the caller is:
app = FastAPI()
templates = Jinja2Templates(directory="templates")
vectorstore: Optional[VectorStore] = None


@app.on_event("startup")
async def startup_event():
    logging.info("loading vectorstore")
    if not Path("vectorstore.pkl").exists():
        raise ValueError("vectorstore.pkl does not exist, please run ingest.py first")
    with open("vectorstore.pkl", "rb") as f:
        global vectorstore
        vectorstore = pickle.load(f)


@app.get("/")
async def get(request: Request):
    return templates.TemplateResponse("index.html", {"request": request})


@app.websocket("/chat")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()
    question_handler = QuestionGenCallbackHandler(websocket)
    stream_handler = StreamingLLMCallbackHandler(websocket)
    chat_history = []
    qa_chain = get_chain(vectorstore, question_handler, stream_handler)
    # Use the below line instead of the above line to enable tracing
    # Ensure `langchain-server` is running
    # qa_chain = get_chain(vectorstore, question_handler, stream_handler, tracing=True)

    while True:
        try:
            # Receive and send back the client message
            question = await websocket.receive_text()
            resp = ChatResponse(sender="you", message=question, type="stream")
            await websocket.send_json(resp.dict())

            # Construct a response
            start_resp = ChatResponse(sender="bot", message="", type="start")
            await websocket.send_json(start_resp.dict())

            result = await qa_chain.acall(
                {"question": question, "chat_history": chat_history}
            )
            chat_history.append((question, result["answer"]))

            end_resp = ChatResponse(sender="bot", message="", type="end")
            await websocket.send_json(end_resp.dict())
        except WebSocketDisconnect:
            logging.info("websocket disconnect")
            break
        except Exception as e:
            logging.error(e)
            resp = ChatResponse(
                sender="bot",
                message="Sorry, something went wrong. Try again.",
                type="error",
            )
            await websocket.send_json(resp.dict())


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=9000)
This works using langchain.vectorstores.Weaviate(…).as_retriever(), but I cannot seem to get it to work with langchain.retrievers.weaviate_hybrid_search.WeaviateHybridSearchRetriever. I can run the chain synchronously with the hybrid retriever, though. I’m also having trouble producing an informative traceback, so unfortunately I cannot provide one. Can someone tell me whether this is expected behavior? Are there missing asynchronous implementations?
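To try to get a traceback, my plan is to probe the retriever’s async path in isolation, outside the chain and the WebSocket handler. A minimal sketch of what I mean (weaviate_client, index_name, and text_key are placeholders for my actual setup, and the constructor arguments may need adjusting):

import asyncio

from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever

# Placeholder construction -- client/index_name/text_key depend on the Weaviate setup.
retriever = WeaviateHybridSearchRetriever(
    client=weaviate_client,
    index_name="LangChain",
    text_key="text",
)


async def probe():
    # The synchronous path works for me:
    print(retriever.get_relevant_documents("test question"))
    # If the async path is the problem, this await should fail with a clearer
    # traceback than the generic error surfaced by the WebSocket handler:
    print(await retriever.aget_relevant_documents("test question"))


asyncio.run(probe())

Since the synchronous call returns documents for me, I’d expect the difference to show up at the await.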
Edit: is it because aget_relevant_documents is not implemented?
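If that’s the cause, would a wrapper along these lines be a reasonable stopgap? This is just a sketch (the class name and approach are mine, not something from the library): it delegates the async method to the working synchronous one via a thread executor, so the chain’s acall path has something to await.

import asyncio
from typing import List

from langchain.retrievers.weaviate_hybrid_search import WeaviateHybridSearchRetriever
from langchain.schema import Document


class AsyncWeaviateHybridSearchRetriever(WeaviateHybridSearchRetriever):
    """Hypothetical wrapper: run the sync hybrid search in a thread so it can be awaited."""

    async def aget_relevant_documents(self, query: str) -> List[Document]:
        loop = asyncio.get_running_loop()
        # Delegate to the synchronous implementation (which works for me) in the
        # default thread pool, since a native async implementation seems to be missing.
        return await loop.run_in_executor(None, self.get_relevant_documents, query)

Then I would construct this subclass with the same arguments and pass it as retriever= to ConversationalRetrievalChain. Is there a better/intended way to do this?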