Mybits plugin& Interceptor & jsqlparse 实现多租户

上一篇文章写道了mybatis框架下自定义拦截器的基本实现。因为项目正好要做多租户的功能,所以我用这张方案实现了一下。

原本平台是基于docker实现多租户方案的,这种方案的优点是省时省力,新建一个租户基本没啥操作,一个dockerfile文件搞定,而且现在的服务商都提供了完整的配套容器服务。但缺点也能明显,每个租户都需要分配一套独立的硬件资源。N个租户意味着要开N个容器,N个数据库,N个tomcat,也有点吃不消。


目前来说,多租户大致分三种方案

  • 1.独立数据库
  • 2.同一个数据库,不同Schema
  • 3.同一个数据库,同一个Schema,每张表使用tenant_id字段区分不同的租户。

个人倾向于第一种(对mysql而言,第一种和第二种没啥区别),直接采用多数据源的方案,对于现有项目改动也不是很大。不过现在要采用第三者方案,所用的租户共享一套数据库。查询了一些资料,orm中hibernat是直接支持多租户方案的,mybatis没有原生支持,但可以通过拦截器实现。

如果采用hibernat,那么现有项目中的mybatis需要全部替换,工作量不较大,后期的测试工作也很麻烦。如果采用mybatis的拦截器改动则比较小,但是mybatis中的xml都是自己写的,拦截后需要自行解析添加tenant_id的过滤。目前来说sql还没找到完美解析和重构的方案,暂时用jsqlparse实现,修改不能解析的复杂sql。


因为拦截器是非侵入式的,所以对项目改动不大,主要也就两点

  • 1.在mybatis配置中注册自己的拦截器
  • 2.自定义自己的拦截器

注册拦截器

    <!-- 用户获取租户id的bean -->
     <bean id="tenantInfoImpl" class="xxx.xxxx.TenantInfoImpl" />。
     <!-- sql拦截过滤器 部分sql查询不需要拦截 -->
     <bean id="sqlFilterImpl" class="xxx.xxxx.SqlFilterImpl"/>

     <!-- 配置mybatis -->
     <bean id="sqlSessionFactory" class="org.mybatis.spring.SqlSessionFactoryBean">
         <property name="dataSource" ref="dataSource"/>
         <property name="configLocation" value="classpath:mybatis/mybatis-config.xml"/>
         <!-- mapper扫描 -->
         <property name="mapperLocations" value="classpath:mybatis/*/*.xml"/>
         <property name="plugins">
             <array>
                 <bean id="tenantInterceptor" class="xxx.xxxx.TenantInterceptor">
                     <!--数据库中租户ID的列名-->
                     <property name="tenantIdColumn" value="tenantId"/>
                     <!--数据库方言-->
                     <property name="dialect" value="mysql"/>
                     <property name="tenantInfo" ref="tenantInfoImpl"/>
                     <property name="sqlFilter" ref="sqlFilterImpl"/>
                 </bean>
             </array>
         </property>
     </bean>

获取租户id

在拦截之前,我们必须知道这个sql是哪个租户的,如果我们无法识别就该考虑下怎么处理的,因为不是每个查询都是用户的操作引发的。
比如项目中的定时操作,触发事件,还有比较特殊的登入操作,这个时候还不知道谁在使用系统。定时操作之类的需要在定义的时候就写入租户信息,自行完成租户的隔离。
如果处于用户登入状态,那么可以直接通过session或者cookie读取用户所绑定的租户信息。



寻找注入点

我们需要明确,在mybatis整个查询过程中,需要在哪个类的哪个方法上执行自定义拦截器以便获取sql并修改。
MyBatis 允许你在已映射语句执行过程中的某一点进行拦截调用。默认情况下,MyBatis 允许使用插件来拦截的方法调用包括:

// 前面是允许用插件拦截的类名,括号里是允许用插件拦截的方法名
Executor (update, query, flushStatements, commit, rollback, getTransaction, close, isClosed)
ParameterHandler (getParameterObject, setParameters)
ResultSetHandler (handleResultSets, handleOutputParameters)
StatementHandler (prepare, parameterize, batch, update, query)

MyBatis是在StatementHandler中的prepare(...)方法中完成对sql的解析,所以我们需要在这个方法前设置一个拦截器也就是plugin来进行sql语句的置换.
但测试发现,resultmap中定义的collection包含的子查询无法被拦截。
经查证,在StatementHandler的接口定义中包含了如下方法定义。

Statement prepare(Connection connection)
      throws SQLException;

  void parameterize(Statement statement)
      throws SQLException;

  void batch(Statement statement)
      throws SQLException;

  int update(Statement statement)
      throws SQLException;

   List query(Statement statement, ResultHandler resultHandler)
      throws SQLException;

  BoundSql getBoundSql();

  ParameterHandler getParameterHandler();

collection标签中的子查询调用的是query方法,我们需要在这个时候处理collection中的子查询,当然普通的查询也会进入update,不过此时再修改sql是不生效的。

最终的注解声明

@Intercepts({@Signature(method = "prepare", type = StatementHandler.class, args = {Connection.class}),
        @Signature(method = "query", type = StatementHandler.class, args = {java.sql.Statement.class, ResultHandler.class})})

分辨需要添加tenant_id条件的sql

mybaitis中每个sql定义都会有一个id用于区分,如果sql定义在xml中,那么完整的id就是namespace加上sql本身的id,如果是定义在接口类的注解上的,那么就是接口的完全限定接口名加方法名。

读取和解析sql

手动实现SQL的解析是万万不能的,简单的select where join还行,但加上子查询 排序 分组 各种函数 类型转换。。。。就hold不住了。github上找到了一个转换器(github.com/JSQLParser/JSqlParser)
经测试,项目里95%的sql都能解析,剩下的只能自己改了。基本都是带了非常偏门的语法。

StatementHandler handler = (StatementHandler) invocation.getTarget();
//由于mappedStatement为protected的,所以要通过反射获取
MetaObject statementHandler = SystemMetaObject.forObject(handler);
//mappedStatement中有我们需要的方法id
MappedStatement mappedStatement = (MappedStatement) statementHandler.getValue("delegate.mappedStatement");
BoundSql boundSql = handler.getBoundSql();
String sql = boundSql.getSql();
String id = mappedStatement.getId();
Statement stmt = CCJSqlParserUtil.parse(sql);
 

修改sql

JSqlParser会把sql字符串转化成基本的java对象,你需要重顶层遍历找到需要修改的地方添加where条件,百度上也有些例子,但都不完整。

    private String addWhere(String sql) throws Exception {
        Statement stmt = CCJSqlParserUtil.parse(sql);
        if (stmt instanceof Insert) {
            //获得Update对象
            Insert insert = (Insert) stmt;
            insert.getColumns().add(new Column(getTenantIdColumn()));

            if (insert.getItemsList() instanceof MultiExpressionList){
                for (ExpressionList expressionList : ((MultiExpressionList) insert.getItemsList()).getExprList()) {
                    addTenantValue(expressionList);
                }
            }else {
                addTenantValue(((ExpressionList) insert.getItemsList()));
            }
            return insert.toString();
        }

        if (stmt instanceof Delete) {
            //获得Delete对象
            Delete deleteStatement = (Delete) stmt;
            Expression where = deleteStatement.getWhere();
            if (where instanceof BinaryExpression) {
                EqualsTo equalsTo = new EqualsTo();
                equalsTo.setLeftExpression(new Column(getTenantIdColumn()));
                equalsTo.setRightExpression(new StringValue(getTenantInfo().getTenantId()));
                AndExpression andExpression = new AndExpression(equalsTo, where);
                deleteStatement.setWhere(andExpression);
            }
            return deleteStatement.toString();
        }

        if (stmt instanceof Update) {
            //获得Update对象
            Update updateStatement = (Update) stmt;
            //获得where条件表达式
            Expression where = updateStatement.getWhere();
         /*   if (where instanceof BinaryExpression) {
                // 针对是否含where条件做不同处理
                if (updateStatement.getWhere() != null) {
                    updateStatement.setWhere(addAndExpression(stmt, getTenantIdColumn(), updateStatement.getWhere()));
                } else {
                    EqualsTo equalsTo = new EqualsTo();
                    equalsTo.setLeftExpression(new Column(getTenantIdColumn()));
                    equalsTo.setRightExpression(new StringValue(getTenantInfo().getTenantId()));
                    updateStatement.setWhere(equalsTo);
                }
            }*/
            TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
            List tableList = tablesNamesFinder.getTableList(stmt);
            // 针对数据库连接测试没有表名的情况处理【select 'x';select 1什么的】
            if (tableList.size() == 0) {
                return updateStatement.toString();
            }
            for (String table : tableList) {
                if (updateStatement.getWhere() != null) {
                    updateStatement.setWhere(addAndExpression(stmt, table, updateStatement.getWhere()));
                } else {
                    throw new Exception("update语句不能没有where条件:" + sql + Arrays.toString(Thread.currentThread().getStackTrace()));
                }
            }
            return updateStatement.toString();
        }

        if (stmt instanceof Select) {
            Select select = (Select) stmt;
            PlainSelect ps = (PlainSelect) select.getSelectBody();
            TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
            List tableList = tablesNamesFinder.getTableList(select);
            if (tableList.size() == 0) {
                return select.toString();
            }
            for (String table : tableList) {
                if (ps.getWhere() != null) {
                    AndExpression where = addAndExpression(stmt, table, ps.getWhere());
                    // form 和 join 中加载的表
                    if (where != null) {
                        ps.setWhere(where);
                    } else {
                        //子查询中的表
                        findSubSelect(stmt, ps.getWhere());
                    }
                } else {
                    ps.setWhere(addEqualsTo(stmt, table));
                }
            }
            return select.toString();
        }

        if (stmt instanceof CreateTable) {
            CreateTable createTable = (CreateTable) stmt;

            ColumnDefinition columnDefinition = new ColumnDefinition();
            columnDefinition.setColumnName("tenantId");
            ColDataType colDataType = new ColDataType();
            colDataType.setDataType("varchar(255)");
            columnDefinition.setColDataType(colDataType);
            createTable.getColumnDefinitions().add(columnDefinition);
            return createTable.toString();
        }

        return null;
    }

    /**
     * 插入数据 添加租户id
     * @param expressionList
     * @throws Exception
     */
    private void addTenantValue(ExpressionList expressionList) throws Exception {
        expressionList.getExpressions().add(new StringValue(getTenantInfo().getTenantId()));
    }

    /**
     * 多条件情况下,使用AndExpression给where条件加上tenantid条件
     *
     * @param table
     * @param where
     * @return
     * @throws Exception
     */
    public AndExpression addAndExpression(Statement stmt, String table, Expression where) throws Exception {
        EqualsTo equalsTo = addEqualsTo(stmt, table);
        if (equalsTo != null) {
            return new AndExpression(equalsTo, where);
        } else {
            return null;
        }
    }

    /**
     * 创建一个 EqualsTo相同判断 条件
     *
     * @param stmt  查询对象
     * @param table 表名
     * @return “A=B” 单个where条件表达式
     * @throws Exception
     */
    public EqualsTo addEqualsTo(Statement stmt, String table) throws Exception {
        EqualsTo equalsTo = new EqualsTo();
        String aliasName;
        aliasName = getTableAlias(stmt, table);
        if (aliasName != null) {
            equalsTo.setLeftExpression(new Column(aliasName + '.' + getTenantIdColumn()));
            equalsTo.setRightExpression(new StringValue(getTenantInfo().getTenantId()));
            return equalsTo;
        } else {
            return null;
        }
    }
    /**
     * 获取sql送指定表的别名你,没有别名则返回原表名 如果表名不存在返回null
     * 【仅查询from和join 不含 IN 子查询中的表 】
     *
     * @param stmt
     * @param tableName
     * @return
     */
    public String getTableAlias(Statement stmt, String tableName) {
        String alias = null;

        // 插入不做处理
        if (stmt instanceof Insert) {
            return tableName;
        }

        if (stmt instanceof Delete) {
            //获得Delete对象
            Delete deleteStatement = (Delete) stmt;

            if ((deleteStatement.getTable()).getName().equalsIgnoreCase(tableName)) {
                alias = deleteStatement.getTable().getAlias() != null ? deleteStatement.getTable().getAlias().getName() : tableName;
            }
        }

        if (stmt instanceof Update) {
            //获得Update对象
            Update updateStatement = (Update) stmt;

            if ((updateStatement.getTables().get(0)).getName().equalsIgnoreCase(tableName)) {
                alias = updateStatement.getTables().get(0).getAlias() != null ? updateStatement.getTables().get(0).getAlias().getName() : tableName;
            }
        }

        if (stmt instanceof Select) {
            Select select = (Select) stmt;

            PlainSelect ps = (PlainSelect) select.getSelectBody();

            // 判断主表的别名
            if (((Table) ps.getFromItem()).getName().equalsIgnoreCase(tableName)) {
                alias = ps.getFromItem().getAlias() != null ? ps.getFromItem().getAlias().getName() : tableName;
            }
        }
        return alias;
    }

    /**
     * 针对子查询中的表别名查询
     *
     * @param subSelect
     * @param tableName
     * @return
     */
    public String getTableAlias(SubSelect subSelect, String tableName) {
        PlainSelect ps = (PlainSelect) subSelect.getSelectBody();
        // 判断主表的别名
        String alias = null;
        if (((Table) ps.getFromItem()).getName().equalsIgnoreCase(tableName)) {
            if (ps.getFromItem().getAlias() != null) {
                alias = ps.getFromItem().getAlias().getName();
            } else {
                alias = tableName;
            }
        }
        return alias;
    }

    /**
     * 递归处理 子查询中的tenantid-where
     *
     * @param stmt  sql查询对象
     * @param where 当前sql的where条件 where为AndExpression或OrExpression的实例,解析其中的rightExpression,然后检查leftExpression是否为空,
     *              不为空则是AndExpression或OrExpression,再次解析其中的rightExpression
     *              注意tenantid-where是加在子查询上的
     */
    void findSubSelect(Statement stmt, Expression where) throws Exception {

        // and 表达式
        if (where instanceof AndExpression) {
            AndExpression andExpression = (AndExpression) where;
            if (andExpression.getRightExpression() instanceof SubSelect) {
                SubSelect subSelect = (SubSelect) andExpression.getRightExpression();
                doSelect(stmt, subSelect);
            }
            if (andExpression.getLeftExpression() != null) {
                findSubSelect(stmt, andExpression.getLeftExpression());
            }
        } else if (where instanceof OrExpression) {
            //  or表达式
            OrExpression orExpression = (OrExpression) where;
            if (orExpression.getRightExpression() instanceof SubSelect) {
                SubSelect subSelect = (SubSelect) orExpression.getRightExpression();
                doSelect(stmt, subSelect);
            }
            if (orExpression.getLeftExpression() != null) {
                findSubSelect(stmt, orExpression.getLeftExpression());
            }
        }
    }

    /**
     * 处理select 和 subSelect
     *
     * @param stmt   查询对象
     * @param select
     * @return
     * @throws Exception
     */
    Expression doSelect(Statement stmt, Expression select) throws Exception {
        PlainSelect ps = null;
        boolean hasSubSelect = false;

        if (select instanceof SubSelect) {
            ps = (PlainSelect) ((SubSelect) select).getSelectBody();
        }
        if (select instanceof Select) {
            ps = (PlainSelect) ((Select) select).getSelectBody();
        }

        TablesNamesFinder tablesNamesFinder = new TablesNamesFinder();
        List tableList = tablesNamesFinder.getTableList(select);
        if (tableList.size() == 0) {
            return select;
        }
        for (String table : tableList) {
            // sql 包含 where 条件的情况 使用 addAndExpression 连接 已有的条件和新条件
            if (ps.getWhere() == null) {
                AndExpression where = addAndExpression(stmt, table, ps.getWhere());
                // form 和 join 中加载的表
                if (where != null) {
                    ps.setWhere(where);
                } else {
                    // 如果在Statement中不存在这个表名,则存在于子查询中
                    hasSubSelect = true;
                }
            } else {
                // sql 不含 where条件 新建一个EqualsTo设置为where条件
                EqualsTo equalsTo = addEqualsTo(stmt, table);
                ps.setWhere(equalsTo);
            }
        }

        if (hasSubSelect) {
            //子查询中的表
            findSubSelect(stmt, ps.getWhere());
        }
        return select;
    }


相关代码打包下载
code

PS:
这块解析和处理租户id隔离的代码已经被我重构了,基本内容是一样了,只是进行了必要的拆分和扩充
目前似乎还没处理where条件的值是子查询的情况,比如 where id in (select id from user),in 中的子查询没添加tenant_id的限制条件。
代码打包:tenant