/**
* RAGProxyServlet
* Copyright 2024 by Michael Peter Christen
* First released 17.05.2024 at http://yacy.net
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with this program in the file lgpl21.txt
 * If not, see <http://www.gnu.org/licenses/>.
*/
package net.yacy.http.servlets;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;
import net.yacy.cora.federate.solr.connector.EmbeddedSolrConnector;
import net.yacy.search.Switchboard;
import net.yacy.search.schema.CollectionSchema;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.common.SolrDocument;
import org.apache.solr.common.SolrDocumentList;
import org.apache.solr.common.SolrException;
import org.apache.solr.servlet.cache.Method;
import javax.servlet.ServletException;
import javax.servlet.ServletOutputStream;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
/**
 * This class implements a Retrieval Augmented Generation ("RAG") proxy which uses a YaCy search index
 * to enrich a chat with search results. The servlet accepts OpenAI-style chat completion requests,
 * augments the latest user prompt with texts retrieved from the local index, and streams the answer
 * of a back-end LLM (e.g. Ollama) back to the client.
 */
public class RAGProxyServlet extends HttpServlet {
private static final long serialVersionUID = 3411544789759603107L;
    private static String[] STOPTOKENS = new String[]{"[/INST]", "<|im_end|>", "<|end_of_turn|>", "<|eot_id|>", "<|end_header_id|>", "</s>", "<s>", "<|end|>"};
private static Boolean LLM_ENABLED = false;
private static Boolean LLM_CONTROL_OLLAMA = true;
    private static Boolean LLM_ATTACH_QUERY = false; // instructs the proxy to attach the query that was generated for the RAG search
    private static Boolean LLM_ATTACH_REFERENCES = false; // instructs the proxy to attach a list of the sources that were used for RAG
private static String LLM_LANGUAGE = "en"; // used to select proper language in RAG augmentation
    private static String LLM_SYSTEM_PREFIX = "\n\nYou may receive additional expert knowledge in the user prompt after an 'Additional Information' headline to enhance your knowledge. Use it only if applicable.";
    private static String LLM_USER_PREFIX = "\n\nAdditional Information:\n\nBelow you find a collection of texts that might be useful to generate a response. Do not discuss these documents, just use them to answer the question above.\n\n";
private static String LLM_API_HOST = "http://localhost:11434"; // Ollama port; install ollama from https://ollama.com/
private static String LLM_QUERY_MODEL = "phi3:3.8b";
    private static String LLM_ANSWER_MODEL = "llama3:8b"; // or "phi3:3.8b", e.g. on a Raspberry Pi 5
private static Boolean LLM_API_MODEL_OVERWRITING = true; // if true, the value configured in YaCy overwrites the client model
    private static String LLM_API_KEY = ""; // not required for Ollama; option to use this class with an OpenAI-compatible API
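    /**
     * Handle an OpenAI-compatible chat completion request: derive search words from the latest user
     * prompt, query the local YaCy index for matching documents, append the retrieved texts to the
     * prompt, forward the augmented request to the back-end LLM and stream its answer back to the
     * client, prefixed by progress messages in the same chat.completion.chunk format.
     */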
@Override
public void service(ServletRequest request, ServletResponse response) throws IOException, ServletException {
response.setContentType("application/json;charset=utf-8");
HttpServletResponse hresponse = (HttpServletResponse) response;
HttpServletRequest hrequest = (HttpServletRequest) request;
// Add CORS headers
hresponse.setHeader("Access-Control-Allow-Origin", "*");
hresponse.setHeader("Access-Control-Allow-Methods", "POST, GET, OPTIONS, DELETE");
hresponse.setHeader("Access-Control-Allow-Headers", "Content-Type, Authorization");
final Method reqMethod = Method.getMethod(hrequest.getMethod());
if (reqMethod == Method.OTHER) {
            // this is a CORS pre-flight (e.g. OPTIONS) request: acknowledge it with the CORS headers set above
hresponse.setStatus(HttpServletResponse.SC_OK);
return;
}
// We expect a POST request
if (reqMethod != Method.POST) {
hresponse.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED);
return;
}
// get the output stream early to be able to generate messages to the user before the actual retrieval starts
ServletOutputStream out = response.getOutputStream();
// read the body of the request and parse it as JSON
BufferedReader reader = request.getReader();
StringBuilder bodyBuilder = new StringBuilder();
String line;
while ((line = reader.readLine()) != null) {
bodyBuilder.append(line);
}
String body = bodyBuilder.toString();
JSONObject bodyObject;
try {
// get system message and user prompt
bodyObject = new JSONObject(body);
String model = bodyObject.optString("model", LLM_ANSWER_MODEL); // we need a switch to allow overwriting
JSONArray messages = bodyObject.optJSONArray("messages");
JSONObject systemObject = messages.getJSONObject(0);
String system = systemObject.optString("content", ""); // the system prompt
JSONObject userObject = messages.getJSONObject(messages.length() - 1);
String user = userObject.optString("content", ""); // this is the latest prompt
// modify system and user prompt here in bodyObject to enable RAG
String query = searchWordsForPrompt(LLM_QUERY_MODEL, user);
out.print(responseLine("Searching for '" + query + "'\n\n").toString() + "\n"); out.flush();
            LinkedHashMap<String, String> searchResults = searchResults(query, 4);
out.print(responseLine("Using the following sources for RAG:\n\n").toString() + "\n"); out.flush();
for (String s: searchResults.keySet()) {out.print(responseLine("- `" + s + "`\n").toString() + "\n"); out.flush();}
out.print(responseLine("\n").toString()); out.flush();
system += LLM_SYSTEM_PREFIX;
user += LLM_USER_PREFIX;
for (String s: searchResults.values()) user += s + "\n\n";
systemObject.put("content", system);
userObject.put("content", user);
if (LLM_API_MODEL_OVERWRITING) bodyObject.put("model", LLM_ANSWER_MODEL);
// write back modified bodyMap to body
body = bodyObject.toString();
// Open request to back-end service
URL url = new URI(LLM_API_HOST + "/v1/chat/completions").toURL();
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
conn.setRequestMethod("POST");
conn.setRequestProperty("Content-Type", "application/json");
if (!LLM_API_KEY.isEmpty()) {
conn.setRequestProperty("Authorization", "Bearer " + LLM_API_KEY);
}
conn.setDoOutput(true);
// write the body to back-end LLM
try (OutputStream os = conn.getOutputStream()) {
                os.write(body.getBytes("utf-8"));
os.flush();
}
// write back response of the back-end service to the client; use status of backend-response
int status = conn.getResponseCode();
String rmessage = conn.getResponseMessage();
hresponse.setStatus(status);
if (status == 200) {
// read the response of the back-end line-by-line and write it to the client line-by-line
BufferedReader in = new BufferedReader(new InputStreamReader(conn.getInputStream()));
String inputLine;
while ((inputLine = in.readLine()) != null) {
                    out.println(inputLine); // e.g. data: {"id":"chatcmpl-69","object":"chat.completion.chunk","created":1715908287,"model":"llama3:8b","system_fingerprint":"fp_ollama","choices":[{"index":0,"delta":{"role":"assistant","content":"..."},"finish_reason":null}]}
out.flush();
}
in.close();
}
out.close(); // close this here to end transmission
} catch (JSONException | URISyntaxException e) {
throw new IOException(e.getMessage());
}
}
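    /**
     * Wrap a log/progress message into an OpenAI-style "chat.completion.chunk" object so that it can
     * be streamed to the client in the same format as the answer chunks produced by the back-end LLM.
     */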
private static JSONObject responseLine(String payload) {
JSONObject j = new JSONObject(true);
try {
j.put("id", "log");
j.put("object", "chat.completion.chunk");
j.put("created", System.currentTimeMillis() / 1000);
j.put("model", "log");
j.put("system_fingerprint", "YaCy");
JSONArray choices = new JSONArray();
            JSONObject choice = new JSONObject(true); // e.g. {"index":0,"delta":{"role":"assistant","content":"..."}}
choice.put("index", 0);
JSONObject delta = new JSONObject(true);
delta.put("role", "assistant");
delta.put("content", payload);
choice.put("delta", delta);
choices.put(choice);
j.put("choices", choices);
//j.put("finish_reason", null); // this is problematic with the JSON library
} catch (JSONException e) {}
return j;
}
// API Helper Methods for Ollama
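    /**
     * POST the given JSON object to the endpoint and return the response body as a single string;
     * throws an IOException if the endpoint does not answer with HTTP 200.
     */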
private static String sendPostRequest(String endpoint, JSONObject data) throws IOException, URISyntaxException {
URL url = new URI(endpoint).toURL();
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
conn.setRequestMethod("POST");
conn.setRequestProperty("Content-Type", "application/json");
conn.setDoOutput(true);
try (OutputStream os = conn.getOutputStream()) {
byte[] input = data.toString().getBytes("utf-8");
os.write(input, 0, input.length);
}
int responseCode = conn.getResponseCode();
if (responseCode == HttpURLConnection.HTTP_OK) {
try (BufferedReader br = new BufferedReader(new InputStreamReader(conn.getInputStream(), "utf-8"))) {
StringBuilder response = new StringBuilder();
String responseLine;
while ((responseLine = br.readLine()) != null) {
response.append(responseLine.trim());
}
return response.toString();
}
} else {
throw new IOException("Request failed with response code " + responseCode);
}
}
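    /**
     * GET the endpoint and return the response body as a single string;
     * throws an IOException if the endpoint does not answer with HTTP 200.
     */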
private static String sendGetRequest(String endpoint) throws IOException, URISyntaxException {
URL url = new URI(endpoint).toURL();
HttpURLConnection conn = (HttpURLConnection) url.openConnection();
conn.setRequestMethod("GET");
int responseCode = conn.getResponseCode();
if (responseCode == HttpURLConnection.HTTP_OK) {
try (BufferedReader br = new BufferedReader(new InputStreamReader(conn.getInputStream(), "utf-8"))) {
StringBuilder response = new StringBuilder();
String responseLine;
while ((responseLine = br.readLine()) != null) {
response.append(responseLine.trim());
}
return response.toString();
}
} else {
throw new IOException("Request failed with response code " + responseCode);
}
}
// OpenAI chat client, works also with llama.cpp and Ollama
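    // Sends a single non-streaming chat completion request with a short system prompt, a low
    // temperature and the configured stop tokens, and returns only the content of the first choice.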
public static String chat(String model, String prompt, int max_tokens) throws IOException {
JSONObject data = new JSONObject();
JSONArray messages = new JSONArray();
JSONObject systemPrompt = new JSONObject(true);
JSONObject userPrompt = new JSONObject(true);
messages.put(systemPrompt);
messages.put(userPrompt);
try {
systemPrompt.put("role", "system");
systemPrompt.put("content", "Make short answers.");
userPrompt.put("role", "user");
userPrompt.put("content", prompt);
data.put("model", model);
data.put("temperature", 0.1);
data.put("max_tokens", max_tokens);
data.put("messages", messages);
data.put("stop", new JSONArray(STOPTOKENS));
data.put("stream", false);
String response = sendPostRequest(LLM_API_HOST + "/v1/chat/completions", data);
JSONObject responseObject = new JSONObject(response);
JSONArray choices = responseObject.getJSONArray("choices");
JSONObject choice = choices.getJSONObject(0);
JSONObject message = choice.getJSONObject("message");
String content = message.optString("content", "");
return content;
} catch (JSONException | URISyntaxException e) {
throw new IOException(e.getMessage());
}
}
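    /**
     * Extract the first JSON array embedded in a chat answer and return its elements as strings;
     * returns an empty array if the answer does not contain a parseable JSON array.
     */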
public static String[] stringsFromChat(String answer) {
int p = answer.indexOf('[');
int q = answer.indexOf(']');
if (p < 0 || q < 0 || q < p) return new String[0];
try {
JSONArray a = new JSONArray(answer.substring(p, q + 1));
String[] arr = new String[a.length()];
for (int i = 0; i < a.length(); i++) arr[i] = a.getString(i);
return arr;
} catch (JSONException e) {
return new String[0];
}
}
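    /**
     * Ask the query model to condense the user prompt into at most four search words (requested as a
     * JSON array) and return them as a space-separated query string; returns an empty string on failure.
     */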
private static String searchWordsForPrompt(String model, String prompt) {
StringBuilder query = new StringBuilder();
String question = "Make a list of a maximum of four search words for the following question; use a JSON Array: " + prompt;
try {
String[] a = stringsFromChat(chat(model, question, 80));
for (String s: a) query.append(s).append(' ');
return query.toString().trim();
} catch (IOException e) {
e.printStackTrace();
return "";
}
}
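    /**
     * Query the embedded Solr index for the given search words and return a map from document URL (sku)
     * to document text (text_t), limited to the given number of results; returns an empty map on failure.
     */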
    private static LinkedHashMap<String, String> searchResults(String query, int count) {
Switchboard sb = Switchboard.getSwitchboard();
EmbeddedSolrConnector connector = sb.index.fulltext().getDefaultEmbeddedConnector();
// construct query
final SolrQuery params = new SolrQuery();
params.setQuery(CollectionSchema.text_t.getSolrFieldName() + ":" + query);
params.setRows(count);
params.setStart(0);
params.setFacet(false);
params.clearSorts();
params.setFields(CollectionSchema.sku.getSolrFieldName(), CollectionSchema.text_t.getSolrFieldName());
params.setIncludeScore(false);
params.set("df", CollectionSchema.text_t.getSolrFieldName());
// query the server
try {
final SolrDocumentList sdl = connector.getDocumentListByParams(params);
            LinkedHashMap<String, String> a = new LinkedHashMap<>();
            Iterator<SolrDocument> i = sdl.iterator();
while (i.hasNext()) {
SolrDocument doc = i.next();
String url = (String) doc.getFieldValue(CollectionSchema.sku.getSolrFieldName());
String text = (String) doc.getFieldValue(CollectionSchema.text_t.getSolrFieldName());
a.put(url, text);
}
return a;
} catch (SolrException | IOException e) {
            return new LinkedHashMap<>();
}
}
// Ollama client functions
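    /**
     * Read the locally installed Ollama models from /api/tags and return a map from model name to
     * model size, sorted by size in descending order.
     */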
    public static LinkedHashMap<String, Long> listOllamaModels() {
        LinkedHashMap<String, Long> sortedMap = new LinkedHashMap<>();
try {
String response = sendGetRequest(LLM_API_HOST + "/api/tags");
JSONObject responseObject = new JSONObject(response);
JSONArray models = responseObject.getJSONArray("models");
            List<Map.Entry<String, Long>> list = new ArrayList<>();
for (int i = 0; i < models.length(); i++) {
JSONObject model = models.getJSONObject(i);
String name = model.optString("name", "");
long size = model.optLong("size", 0);
                list.add(new AbstractMap.SimpleEntry<>(name, size));
}
// Sort the list in descending order based on the values
list.sort((o1, o2) -> o2.getValue().compareTo(o1.getValue()));
// Create a new LinkedHashMap and add the sorted entries
            for (Map.Entry<String, Long> entry : list) {
sortedMap.put(entry.getKey(), entry.getValue());
}
} catch (JSONException | URISyntaxException | IOException e) {
e.printStackTrace();
}
return sortedMap;
}
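    /**
     * Check whether a model is installed in the local Ollama instance by calling /api/show;
     * any error response is interpreted as "model not installed".
     */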
public static boolean ollamaModelExists(String name) {
JSONObject data = new JSONObject();
try {
data.put("name", name);
sendPostRequest(LLM_API_HOST + "/api/show", data);
return true;
} catch (JSONException | URISyntaxException | IOException e) {
return false;
}
}
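    /**
     * Pull (download) a model into the local Ollama instance via /api/pull (non-streaming)
     * and report whether Ollama answered with {"status": "success"}.
     */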
public static boolean pullOllamaModel(String name) {
JSONObject data = new JSONObject();
try {
data.put("name", name);
data.put("stream", false);
String response = sendPostRequest(LLM_API_HOST + "/api/pull", data);
// this sends {"status": "success"} in case of success
JSONObject responseObject = new JSONObject(response);
String status = responseObject.optString("status", "");
return status.equals("success");
} catch (JSONException | URISyntaxException | IOException e) {
return false;
}
}
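    /**
     * Manual smoke test; requires a running Ollama instance at LLM_API_HOST: lists the installed
     * models, checks and pulls a test model and runs two example chat completions against it.
     */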
public static void main(String[] args) {
        LinkedHashMap<String, Long> models = listOllamaModels();
System.out.println(models.toString());
// check if model exists
//String model = "phi3:3.8b";
String model = "gemma:2b";
if (ollamaModelExists(model))
System.out.println("model " + model + " exists");
else
System.out.println("model " + model + " does not exist");
// pull a model
boolean success = pullOllamaModel(model);
System.out.println("pulled model + " + model + ": " + success);
// make chat completion with model
String question = "Who invented the wheel?";
try {
String answer = chat(model, question, 80);
System.out.println(answer);
} catch (IOException e) {
e.printStackTrace();
}
// try the json parser from chat results
question = "Make a list of four names from Star Wars movies. Use a JSON Array.";
try {
String[] a = stringsFromChat(chat(model, question, 80));
for (String s: a) System.out.println(s);
} catch (IOException e) {
e.printStackTrace();
}
}
}