Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into support_limit
Browse files Browse the repository at this point in the history
  • Loading branch information
daikon12 committed Sep 19, 2024
2 parents 243bd60 + 5ba401a commit e63cf88
Show file tree
Hide file tree
Showing 96 changed files with 1,394 additions and 1,101 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ public static User getFakeUser() {
return new User(1L, "admin", "admin", "admin@email", 1);
}

public static User getVisitUser() {
return new User(1L, "visit", "visit", "visit@email", 0);
}

public static User getAppUser(int appId) {
String name = String.format("app_%s", appId);
return new User(1L, name, name, "", 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
import com.tencent.supersonic.common.util.S2ThreadContext;
import com.tencent.supersonic.common.util.ThreadContext;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.web.method.HandlerMethod;

import java.lang.reflect.Method;
Expand Down Expand Up @@ -61,7 +60,7 @@ public boolean preHandle(
}

UserWithPassword user = userTokenUtils.getUserWithPassword(request);
if (StringUtils.isNotBlank(user.getName())) {
if (user != null) {
setContext(user.getName(), request);
return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;

import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_CREATE_TIME;
import static com.tencent.supersonic.auth.api.authentication.constant.UserConstants.TOKEN_IS_ADMIN;
Expand Down Expand Up @@ -68,13 +69,13 @@ public String generateAdminToken(HttpServletRequest request) {

public User getUser(HttpServletRequest request) {
String token = request.getHeader(authenticationConfig.getTokenHttpHeaderKey());
final Claims claims = getClaims(token, request);
return getUser(claims);
final Optional<Claims> claimsOptional = getClaims(token, request);
return claimsOptional.map(this::getUser).orElse(User.getVisitUser());
}

public User getUser(String token, String appKey) {
final Claims claims = getClaims(token, appKey);
return getUser(claims);
final Optional<Claims> claimsOptional = getClaims(token, appKey);
return claimsOptional.map(this::getUser).orElse(User.getVisitUser());
}

private User getUser(Claims claims) {
Expand All @@ -92,11 +93,13 @@ private User getUser(Claims claims) {
public UserWithPassword getUserWithPassword(HttpServletRequest request) {
String token = request.getHeader(authenticationConfig.getTokenHttpHeaderKey());
if (StringUtils.isBlank(token)) {
String message = "token is blank, get user failed";
log.warn("{}, uri: {}", message, request.getServletPath());
throw new AccessException(message);
return null;
}
final Claims claims = getClaims(token, request);
final Optional<Claims> claimsOptional = getClaims(token, request);
if (!claimsOptional.isPresent()) {
return null;
}
final Claims claims = claimsOptional.get();
Long userId = Long.parseLong(claims.getOrDefault(TOKEN_USER_ID, 0).toString());
String userName = String.valueOf(claims.get(TOKEN_USER_NAME));
String email = String.valueOf(claims.get(TOKEN_USER_EMAIL));
Expand All @@ -109,32 +112,25 @@ public UserWithPassword getUserWithPassword(HttpServletRequest request) {
return UserWithPassword.get(userId, userName, displayName, email, password, isAdmin);
}

private Claims getClaims(String token, HttpServletRequest request) {
Claims claims;
try {
String appKey = getAppKey(request);
claims = getClaims(token, appKey);
} catch (Exception e) {
throw new AccessException("parse user info from token failed :" + token);
}
return claims;
private Optional<Claims> getClaims(String token, HttpServletRequest request) {
String appKey = getAppKey(request);
return getClaims(token, appKey);
}

private Claims getClaims(String token, String appKey) {
Claims claims;
private Optional<Claims> getClaims(String token, String appKey) {
try {
String tokenSecret = getTokenSecret(appKey);
claims =
Claims claims =
Jwts.parser()
.setSigningKey(tokenSecret.getBytes(StandardCharsets.UTF_8))
.build()
.parseClaimsJws(getTokenString(token))
.getBody();
return Optional.of(claims);
} catch (Exception e) {
log.error("getClaims", e);
throw new AccessException("parse user info from token failed :" + token);
log.info("can not getClaims from appKey:{} token:{}, please login", appKey, token);
}
return claims;
return Optional.empty();
}

private static String getTokenString(String token) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ public PluginRecallResult recallPlugin(ParseContext parseContext) {
continue;
}
plugin.setParseMode(ParseMode.EMBEDDING_RECALL);
double distance = embeddingRetrieval.getDistance();
double score = parseContext.getQueryText().length() * (1 - distance);
double similarity = embeddingRetrieval.getSimilarity();
double score = parseContext.getQueryText().length() * similarity;
return PluginRecallResult.builder()
.plugin(plugin)
.dataSetIds(dataSetList)
.score(score)
.distance(distance)
.distance(similarity)
.build();
}
}
Expand All @@ -73,7 +73,9 @@ public List<Retrieval> embeddingRecall(String embeddingText) {
if (!CollectionUtils.isEmpty(embeddingRetrievals)) {
embeddingRetrievals =
embeddingRetrievals.stream()
.sorted(Comparator.comparingDouble(o -> Math.abs(o.getDistance())))
.sorted(
Comparator.comparingDouble(
o -> Math.abs(o.getSimilarity())))
.collect(Collectors.toList());
embeddingResp.setRetrieval(embeddingRetrievals);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ private void fillSimilarMetric(SemanticParseInfo parseInfo) {
List<Retrieval> retrievals =
retrieveQueryResults.stream()
.flatMap(retrieveQueryResult -> retrieveQueryResult.getRetrieval().stream())
.sorted(Comparator.comparingDouble(Retrieval::getDistance))
.sorted(Comparator.comparingDouble(Retrieval::getSimilarity))
.distinct()
.collect(Collectors.toList());
Set<Long> metricIds =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ public Object queryData(ChatQueryDataReq chatQueryDataReq, User user) throws Exc
SemanticParseInfo parseInfo =
chatManageService.getParseInfo(chatQueryDataReq.getQueryId(), parseId);
parseInfo = mergeParseInfo(parseInfo, chatQueryDataReq);
parseInfo.setSqlInfo(new SqlInfo());
DataSetSchema dataSetSchema =
semanticLayerService.getDataSetSchema(parseInfo.getDataSetId());

Expand Down Expand Up @@ -559,8 +560,12 @@ private void validFilter(Set<QueryFilter> filters) {
iterator.remove();
continue;
}
List<String> collection =
JsonUtil.toList(JsonUtil.toString(queryFilterValue), String.class);
List<String> collection = new ArrayList<>();
if (queryFilterValue instanceof List) {
collection.addAll((List) queryFilterValue);
} else if (queryFilterValue instanceof String) {
collection.add((String) queryFilterValue);
}
if (FilterOperatorEnum.IN.equals(queryFilter.getOperator())
&& CollectionUtils.isEmpty(collection)) {
iterator.remove();
Expand Down
34 changes: 8 additions & 26 deletions common/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -186,32 +186,6 @@
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-spring-boot-starter</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-anthropic-spring-boot-starter</artifactId>
<version>${langchain4j.version}</version>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-ollama-spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-open-ai-spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-azure-ai-search-spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-azure-open-ai-spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-embeddings-all-minilm-l6-v2-q</artifactId>
Expand Down Expand Up @@ -244,6 +218,10 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>dev.langchain4j</groupId>
<artifactId>langchain4j-ollama</artifactId>
</dependency>
<!--langchain4j-->
<dependency>
<groupId>com.hankcs</groupId>
Expand All @@ -261,6 +239,10 @@
<artifactId>gson</artifactId>
</dependency>

<dependency>
<groupId>org.codehaus.woodstox</groupId>
<artifactId>stax2-api</artifactId>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,12 @@ public ChatModelConfig convert() {
private static List<String> getCandidateValues() {
return Lists.newArrayList(
OpenAiModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER,
LocalAiModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER);
DashscopeModelFactory.PROVIDER,
AzureModelFactory.PROVIDER);
}

private static List<Parameter.Dependency> getBaseUrlDependency() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ private static ArrayList<String> getCandidateValues() {
InMemoryModelFactory.PROVIDER,
OpenAiModelFactory.PROVIDER,
OllamaModelFactory.PROVIDER,
AzureModelFactory.PROVIDER,
DashscopeModelFactory.PROVIDER,
QianfanModelFactory.PROVIDER,
ZhipuModelFactory.PROVIDER);
ZhipuModelFactory.PROVIDER,
AzureModelFactory.PROVIDER);
}

private static List<Parameter.Dependency> getBaseUrlDependency() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package com.tencent.supersonic.common.jsqlparser;

import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
import net.sf.jsqlparser.expression.operators.conditional.OrExpression;
import net.sf.jsqlparser.expression.operators.relational.IsNullExpression;
import net.sf.jsqlparser.schema.Column;
import net.sf.jsqlparser.util.deparser.ExpressionDeParser;

import java.util.Set;

public class CustomExpressionDeParser extends ExpressionDeParser {

private Set<String> removeFieldNames;
private boolean dealNull;
private boolean dealNotNull;

public CustomExpressionDeParser(
Set<String> removeFieldNames, boolean dealNull, boolean dealNotNull) {
this.removeFieldNames = removeFieldNames;
this.dealNull = dealNull;
this.dealNotNull = dealNotNull;
}

@Override
public void visit(AndExpression andExpression) {
processBinaryExpression(andExpression, " AND ");
}

@Override
public void visit(OrExpression orExpression) {
processBinaryExpression(orExpression, " OR ");
}

@Override
public void visit(IsNullExpression isNullExpression) {
if (shouldSkip(isNullExpression)) {
// Skip this expression
} else {
super.visit(isNullExpression);
}
}

private void processBinaryExpression(Expression binaryExpression, String operator) {
Expression leftExpression = ((AndExpression) binaryExpression).getLeftExpression();
Expression rightExpression = ((AndExpression) binaryExpression).getRightExpression();

boolean leftIsNull =
leftExpression instanceof IsNullExpression
&& shouldSkip((IsNullExpression) leftExpression);
boolean rightIsNull =
rightExpression instanceof IsNullExpression
&& shouldSkip((IsNullExpression) rightExpression);

if (leftIsNull && rightIsNull) {
// Skip both expressions
} else if (leftIsNull) {
rightExpression.accept(this);
} else if (rightIsNull) {
leftExpression.accept(this);
} else {
leftExpression.accept(this);
buffer.append(operator);
rightExpression.accept(this);
}
}

private boolean shouldSkip(IsNullExpression isNullExpression) {
if (isNullExpression.getLeftExpression() instanceof Column) {
Column column = (Column) isNullExpression.getLeftExpression();
String columnName = column.getColumnName();
// Add your target column names here
if (removeFieldNames.contains(columnName)) {
if (isNullExpression.isNot() && dealNotNull) {
return true;
} else if (!isNullExpression.isNot() && dealNull) {
return true;
}
}
}
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import net.sf.jsqlparser.statement.select.SelectItem;
import net.sf.jsqlparser.statement.select.SelectItemVisitorAdapter;
import net.sf.jsqlparser.statement.select.SelectVisitorAdapter;
import net.sf.jsqlparser.util.deparser.ExpressionDeParser;
import net.sf.jsqlparser.util.deparser.SelectDeParser;
import org.springframework.util.CollectionUtils;

import java.util.ArrayList;
Expand Down Expand Up @@ -382,6 +384,42 @@ private static Expression distinguishNumberFilter(
}
}

public static String removeIsNullInWhere(String sql, Set<String> removeFieldNames) {
return removeIsNullOrNotNullInWhere(true, false, sql, removeFieldNames);
}

public static String removeNotNullInWhere(String sql, Set<String> removeFieldNames) {
return removeIsNullOrNotNullInWhere(false, true, sql, removeFieldNames);
}

public static String removeIsNullOrNotNullInWhere(
boolean dealNull, boolean dealNotNull, String sql, Set<String> removeFieldNames) {
Select selectStatement = SqlSelectHelper.getSelect(sql);
if (!(selectStatement instanceof PlainSelect)) {
return sql;
}
// Create a custom ExpressionDeParser to remove specific IS NULL and IS NOT NULL conditions
ExpressionDeParser expressionDeParser =
new CustomExpressionDeParser(removeFieldNames, dealNull, dealNotNull);

StringBuilder buffer = new StringBuilder();
SelectDeParser selectDeParser = new SelectDeParser(expressionDeParser, buffer);
expressionDeParser.setSelectVisitor(selectDeParser);
expressionDeParser.setBuffer(buffer);
PlainSelect plainSelect = (PlainSelect) selectStatement.getSelectBody();
if (plainSelect.getWhere() != null) {
plainSelect.getWhere().accept(expressionDeParser);
}
// Parse the modified WHERE clause back to an Expression
try {
Expression newWhere = CCJSqlParserUtil.parseCondExpression(buffer.toString());
plainSelect.setWhere(newWhere);
} catch (Exception e) {
log.error("parseCondExpression error:{}", buffer, e);
}
return selectStatement.toString();
}

private static boolean isInvalidSelect(Select selectStatement) {
return Objects.isNull(selectStatement) || !(selectStatement instanceof PlainSelect);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,6 @@ public interface EmbeddingService {

List<RetrieveQueryResult> retrieveQuery(
String collectionName, RetrieveQuery retrieveQuery, int num);

void removeAll();
}
Loading

0 comments on commit e63cf88

Please sign in to comment.