Mybits plugin& Interceptor & jsqlparse 实现多租户
上一篇文章写道了mybatis框架下自定义拦截器的基本实现。因为项目正好要做多租户的功能,所以我用这张方案实现了一下。
原本平台是基于docker实现多租户方案的,这种方案的优点是省时省力,新建一个租户基本没啥操作,一个dockerfile文件搞定,而且现在的服务商都提供了完整的配套容器服务。但缺点也能明显,每个租户都需要分配一套独立的硬件资源。N个租户意味着要开N个容器,N个数据库,N个tomcat,也有点吃不消。
目前来说,多租户大致分三种方案
- 1.独立数据库
- 2.同一个数据库,不同Schema
- 3.同一个数据库,同一个Schema,每张表使用tenant_id字段区分不同的租户。
个人倾向于第一种(对mysql而言,第一种和第二种没啥区别),直接采用多数据源的方案,对于现有项目改动也不是很大。不过现在要采用第三者方案,所用的租户共享一套数据库。查询了一些资料,orm中hibernat是直接支持多租户方案的,mybatis没有原生支持,但可以通过拦截器实现。
如果采用hibernat,那么现有项目中的mybatis需要全部替换,工作量不较大,后期的测试工作也很麻烦。如果采用mybatis的拦截器改动则比较小,但是mybatis中的xml都是自己写的,拦截后需要自行解析添加tenant_id的过滤。目前来说sql还没找到完美解析和重构的方案,暂时用jsqlparse实现,修改不能解析的复杂sql。
因为拦截器是非侵入式的,所以对项目改动不大,主要也就两点
- 1.在mybatis配置中注册自己的拦截器
- 2.自定义自己的拦截器
注册拦截器
<!-- 用户获取租户id的bean --> <bean id="tenantInfoImpl" class="xxx.xxxx.TenantInfoImpl" />。 <!-- sql拦截过滤器 部分sql查询不需要拦截 --> <bean id="sqlFilterImpl" class="xxx.xxxx.SqlFilterImpl"/> <!-- 配置mybatis --> <bean id="sqlSessionFactory" class="org.mybatis.spring.SqlSessionFactoryBean"> <property name="dataSource" ref="dataSource"/> <property name="configLocation" value="classpath:mybatis/mybatis-config.xml"/> <!-- mapper扫描 --> <property name="mapperLocations" value="classpath:mybatis/*/*.xml"/> <property name="plugins"> <array> <bean id="tenantInterceptor" class="xxx.xxxx.TenantInterceptor"> <!--数据库中租户ID的列名--> <property name="tenantIdColumn" value="tenantId"/> <!--数据库方言--> <property name="dialect" value="mysql"/> <property name="tenantInfo" ref="tenantInfoImpl"/> <property name="sqlFilter" ref="sqlFilterImpl"/> </bean> </array> </property> </bean>
获取租户id
在拦截之前,我们必须知道这个sql是哪个租户的,如果我们无法识别就该考虑下怎么处理的,因为不是每个查询都是用户的操作引发的。
比如项目中的定时操作,触发事件,还有比较特殊的登入操作,这个时候还不知道谁在使用系统。定时操作之类的需要在定义的时候就写入租户信息,自行完成租户的隔离。
如果处于用户登入状态,那么可以直接通过session或者cookie读取用户所绑定的租户信息。
寻找注入点
我们需要明确,在mybatis整个查询过程中,需要在哪个类的哪个方法上执行自定义拦截器以便获取sql并修改。
MyBatis 允许你在已映射语句执行过程中的某一点进行拦截调用。默认情况下,MyBatis 允许使用插件来拦截的方法调用包括:
// 前面是允许用插件拦截的类名,括号里是允许用插件拦截的方法名 Executor (update, query, flushStatements, commit, rollback, getTransaction, close, isClosed) ParameterHandler (getParameterObject, setParameters) ResultSetHandler (handleResultSets, handleOutputParameters) StatementHandler (prepare, parameterize, batch, update, query)
MyBatis是在StatementHandler中的prepare(...)方法中完成对sql的解析,所以我们需要在这个方法前设置一个拦截器也就是plugin来进行sql语句的置换.
但测试发现,resultmap中定义的collection包含的子查询无法被拦截。
经查证,在StatementHandler的接口定义中包含了如下方法定义。
Statement prepare(Connection connection) throws SQLException; void parameterize(Statement statement) throws SQLException; void batch(Statement statement) throws SQLException; int update(Statement statement) throws SQLException;List query(Statement statement, ResultHandler resultHandler) throws SQLException; BoundSql getBoundSql(); ParameterHandler getParameterHandler();
collection标签中的子查询调用的是query方法,我们需要在这个时候处理collection中的子查询,当然普通的查询也会进入update,不过此时再修改sql是不生效的。
最终的注解声明
@Intercepts({@Signature(method = "prepare", type = StatementHandler.class, args = {Connection.class}), @Signature(method = "query", type = StatementHandler.class, args = {java.sql.Statement.class, ResultHandler.class})})
分辨需要添加tenant_id条件的sql
mybaitis中每个sql定义都会有一个id用于区分,如果sql定义在xml中,那么完整的id就是namespace加上sql本身的id,如果是定义在接口类的注解上的,那么就是接口的完全限定接口名加方法名。
读取和解析sql
手动实现SQL的解析是万万不能的,简单的select where join还行,但加上子查询 排序 分组 各种函数 类型转换。。。。就hold不住了。github上找到了一个转换器(github.com/JSQLParser/JSqlParser)
经测试,项目里95%的sql都能解析,剩下的只能自己改了。基本都是带了非常偏门的语法。
StatementHandler handler = (StatementHandler) invocation.getTarget(); //由于mappedStatement为protected的,所以要通过反射获取 MetaObject statementHandler = SystemMetaObject.forObject(handler); //mappedStatement中有我们需要的方法id MappedStatement mappedStatement = (MappedStatement) statementHandler.getValue("delegate.mappedStatement"); BoundSql boundSql = handler.getBoundSql(); String sql = boundSql.getSql(); String id = mappedStatement.getId(); Statement stmt = CCJSqlParserUtil.parse(sql);
修改sql
JSqlParser会把sql字符串转化成基本的java对象,你需要重顶层遍历找到需要修改的地方添加where条件,百度上也有些例子,但都不完整。
private String addWhere(String sql) throws Exception { Statement stmt = CCJSqlParserUtil.parse(sql); if (stmt instanceof Insert) { //获得Update对象 Insert insert = (Insert) stmt; insert.getColumns().add(new Column(getTenantIdColumn())); if (insert.getItemsList() instanceof MultiExpressionList){ for (ExpressionList expressionList : ((MultiExpressionList) insert.getItemsList()).getExprList()) { addTenantValue(expressionList); } }else { addTenantValue(((ExpressionList) insert.getItemsList())); } return insert.toString(); } if (stmt instanceof Delete) { //获得Delete对象 Delete deleteStatement = (Delete) stmt; Expression where = deleteStatement.getWhere(); if (where instanceof BinaryExpression) { EqualsTo equalsTo = new EqualsTo(); equalsTo.setLeftExpression(new Column(getTenantIdColumn())); equalsTo.setRightExpression(new StringValue(getTenantInfo().getTenantId())); AndExpression andExpression = new AndExpression(equalsTo, where); deleteStatement.setWhere(andExpression); } return deleteStatement.toString(); } if (stmt instanceof Update) { //获得Update对象 Update updateStatement = (Update) stmt; //获得where条件表达式 Expression where = updateStatement.getWhere(); /* if (where instanceof BinaryExpression) { // 针对是否含where条件做不同处理 if (updateStatement.getWhere() != null) { updateStatement.setWhere(addAndExpression(stmt, getTenantIdColumn(), updateStatement.getWhere())); } else { EqualsTo equalsTo = new EqualsTo(); equalsTo.setLeftExpression(new Column(getTenantIdColumn())); equalsTo.setRightExpression(new StringValue(getTenantInfo().getTenantId())); updateStatement.setWhere(equalsTo); } }*/ TablesNamesFinder tablesNamesFinder = new TablesNamesFinder(); ListtableList = tablesNamesFinder.getTableList(stmt); // 针对数据库连接测试没有表名的情况处理【select 'x';select 1什么的】 if (tableList.size() == 0) { return updateStatement.toString(); } for (String table : tableList) { if (updateStatement.getWhere() != null) { updateStatement.setWhere(addAndExpression(stmt, table, updateStatement.getWhere())); } else { throw new Exception("update语句不能没有where条件:" + sql + Arrays.toString(Thread.currentThread().getStackTrace())); } } return updateStatement.toString(); } if (stmt instanceof Select) { Select select = (Select) stmt; PlainSelect ps = (PlainSelect) select.getSelectBody(); TablesNamesFinder tablesNamesFinder = new TablesNamesFinder(); List tableList = tablesNamesFinder.getTableList(select); if (tableList.size() == 0) { return select.toString(); } for (String table : tableList) { if (ps.getWhere() != null) { AndExpression where = addAndExpression(stmt, table, ps.getWhere()); // form 和 join 中加载的表 if (where != null) { ps.setWhere(where); } else { //子查询中的表 findSubSelect(stmt, ps.getWhere()); } } else { ps.setWhere(addEqualsTo(stmt, table)); } } return select.toString(); } if (stmt instanceof CreateTable) { CreateTable createTable = (CreateTable) stmt; ColumnDefinition columnDefinition = new ColumnDefinition(); columnDefinition.setColumnName("tenantId"); ColDataType colDataType = new ColDataType(); colDataType.setDataType("varchar(255)"); columnDefinition.setColDataType(colDataType); createTable.getColumnDefinitions().add(columnDefinition); return createTable.toString(); } return null; } /** * 插入数据 添加租户id * @param expressionList * @throws Exception */ private void addTenantValue(ExpressionList expressionList) throws Exception { expressionList.getExpressions().add(new StringValue(getTenantInfo().getTenantId())); } /** * 多条件情况下,使用AndExpression给where条件加上tenantid条件 * * @param table * @param where * @return * @throws Exception */ public AndExpression addAndExpression(Statement stmt, String table, Expression where) throws Exception { EqualsTo equalsTo = addEqualsTo(stmt, table); if (equalsTo != null) { return new AndExpression(equalsTo, where); } else { return null; } } /** * 创建一个 EqualsTo相同判断 条件 * * @param stmt 查询对象 * @param table 表名 * @return “A=B” 单个where条件表达式 * @throws Exception */ public EqualsTo addEqualsTo(Statement stmt, String table) throws Exception { EqualsTo equalsTo = new EqualsTo(); String aliasName; aliasName = getTableAlias(stmt, table); if (aliasName != null) { equalsTo.setLeftExpression(new Column(aliasName + '.' + getTenantIdColumn())); equalsTo.setRightExpression(new StringValue(getTenantInfo().getTenantId())); return equalsTo; } else { return null; } } /** * 获取sql送指定表的别名你,没有别名则返回原表名 如果表名不存在返回null * 【仅查询from和join 不含 IN 子查询中的表 】 * * @param stmt * @param tableName * @return */ public String getTableAlias(Statement stmt, String tableName) { String alias = null; // 插入不做处理 if (stmt instanceof Insert) { return tableName; } if (stmt instanceof Delete) { //获得Delete对象 Delete deleteStatement = (Delete) stmt; if ((deleteStatement.getTable()).getName().equalsIgnoreCase(tableName)) { alias = deleteStatement.getTable().getAlias() != null ? deleteStatement.getTable().getAlias().getName() : tableName; } } if (stmt instanceof Update) { //获得Update对象 Update updateStatement = (Update) stmt; if ((updateStatement.getTables().get(0)).getName().equalsIgnoreCase(tableName)) { alias = updateStatement.getTables().get(0).getAlias() != null ? updateStatement.getTables().get(0).getAlias().getName() : tableName; } } if (stmt instanceof Select) { Select select = (Select) stmt; PlainSelect ps = (PlainSelect) select.getSelectBody(); // 判断主表的别名 if (((Table) ps.getFromItem()).getName().equalsIgnoreCase(tableName)) { alias = ps.getFromItem().getAlias() != null ? ps.getFromItem().getAlias().getName() : tableName; } } return alias; } /** * 针对子查询中的表别名查询 * * @param subSelect * @param tableName * @return */ public String getTableAlias(SubSelect subSelect, String tableName) { PlainSelect ps = (PlainSelect) subSelect.getSelectBody(); // 判断主表的别名 String alias = null; if (((Table) ps.getFromItem()).getName().equalsIgnoreCase(tableName)) { if (ps.getFromItem().getAlias() != null) { alias = ps.getFromItem().getAlias().getName(); } else { alias = tableName; } } return alias; } /** * 递归处理 子查询中的tenantid-where * * @param stmt sql查询对象 * @param where 当前sql的where条件 where为AndExpression或OrExpression的实例,解析其中的rightExpression,然后检查leftExpression是否为空, * 不为空则是AndExpression或OrExpression,再次解析其中的rightExpression * 注意tenantid-where是加在子查询上的 */ void findSubSelect(Statement stmt, Expression where) throws Exception { // and 表达式 if (where instanceof AndExpression) { AndExpression andExpression = (AndExpression) where; if (andExpression.getRightExpression() instanceof SubSelect) { SubSelect subSelect = (SubSelect) andExpression.getRightExpression(); doSelect(stmt, subSelect); } if (andExpression.getLeftExpression() != null) { findSubSelect(stmt, andExpression.getLeftExpression()); } } else if (where instanceof OrExpression) { // or表达式 OrExpression orExpression = (OrExpression) where; if (orExpression.getRightExpression() instanceof SubSelect) { SubSelect subSelect = (SubSelect) orExpression.getRightExpression(); doSelect(stmt, subSelect); } if (orExpression.getLeftExpression() != null) { findSubSelect(stmt, orExpression.getLeftExpression()); } } } /** * 处理select 和 subSelect * * @param stmt 查询对象 * @param select * @return * @throws Exception */ Expression doSelect(Statement stmt, Expression select) throws Exception { PlainSelect ps = null; boolean hasSubSelect = false; if (select instanceof SubSelect) { ps = (PlainSelect) ((SubSelect) select).getSelectBody(); } if (select instanceof Select) { ps = (PlainSelect) ((Select) select).getSelectBody(); } TablesNamesFinder tablesNamesFinder = new TablesNamesFinder(); List tableList = tablesNamesFinder.getTableList(select); if (tableList.size() == 0) { return select; } for (String table : tableList) { // sql 包含 where 条件的情况 使用 addAndExpression 连接 已有的条件和新条件 if (ps.getWhere() == null) { AndExpression where = addAndExpression(stmt, table, ps.getWhere()); // form 和 join 中加载的表 if (where != null) { ps.setWhere(where); } else { // 如果在Statement中不存在这个表名,则存在于子查询中 hasSubSelect = true; } } else { // sql 不含 where条件 新建一个EqualsTo设置为where条件 EqualsTo equalsTo = addEqualsTo(stmt, table); ps.setWhere(equalsTo); } } if (hasSubSelect) { //子查询中的表 findSubSelect(stmt, ps.getWhere()); } return select; }
相关代码打包下载
code
PS:
这块解析和处理租户id隔离的代码已经被我重构了,基本内容是一样了,只是进行了必要的拆分和扩充
目前似乎还没处理where条件的值是子查询的情况,比如 where id in (select id from user),in 中的子查询没添加tenant_id的限制条件。
代码打包:tenant